From 46969ff746428eb7598787c401fe2c95335dab0c Mon Sep 17 00:00:00 2001 From: Nicholas Geneva Date: Mon, 13 Apr 2026 22:04:06 +0000 Subject: [PATCH 1/7] feat: add create and validate diagnostic model wrapper skills Add two Claude skills for diagnostic model (dx) development: - create-diagnostic-wrapper: Steps 0-12 for wrapping both NN-based and physics-based diagnostic models with CONFIRM gates - validate-diagnostic-wrapper: Steps 1-4 for tests, reference comparison, sanity-check plots, PR submission, and Greptile review --- .../skills/create-diagnostic-wrapper/SKILL.md | 1440 +++++++++++++++++ .../validate-diagnostic-wrapper/SKILL.md | 831 ++++++++++ 2 files changed, 2271 insertions(+) create mode 100644 .claude/skills/create-diagnostic-wrapper/SKILL.md create mode 100644 .claude/skills/validate-diagnostic-wrapper/SKILL.md diff --git a/.claude/skills/create-diagnostic-wrapper/SKILL.md b/.claude/skills/create-diagnostic-wrapper/SKILL.md new file mode 100644 index 000000000..9df944198 --- /dev/null +++ b/.claude/skills/create-diagnostic-wrapper/SKILL.md @@ -0,0 +1,1440 @@ +--- +name: create-diagnostic-wrapper +description: Create a new Earth2Studio diagnostic model (dx) wrapper from a reference inference script or repository. Handles both NN-based models (with AutoModelMixin and Package loading) and physics-based calculators (pure torch computations). +argument-hint: URL or local path to reference inference script/repo (optional — will be asked if not provided) +--- + +# Create Diagnostic Model Wrapper + +Create a new Earth2Studio diagnostic model wrapper by following every step below in order. +Each confirmation gate marked by starting with: + +```markdown +### **[CONFIRM — ]** +``` + +requires explicit user approval before proceeding. + +> **Environment note**: Use `uv run python` for all Python +> commands. The project uses a `uv`-managed virtual environment +> — never install packages globally or use bare `python`. + +--- + +## Step 0 — Obtain Reference & Detect Model Type + +### 0a. Obtain reference script / repository + +If `$ARGUMENTS` is provided, use it as the reference inference script or repository. + +- If it is a URL, use WebFetch to retrieve the content. +- If it is a local file path, read it directly. + +If `$ARGUMENTS` is empty or not provided, ask the user: + +> Please provide a reference inference script or +> repository URL/path that demonstrates how this model +> runs inference. This will be used to understand the +> model architecture, dependencies, input/output shapes, +> and variable mapping. + +Store the reference code content for use in subsequent steps. + +### 0b. Detect model type — NN-based vs physics-based + +After obtaining the reference, classify the model into one +of two categories: + +**NN-based indicators** (checkpoint-backed neural network): + +- Checkpoint loading (`torch.load`, `.from_pretrained`, ONNX runtime) +- Trained weights, normalization parameters (mean/std tensors) +- Model architecture classes (e.g., `torch.nn.Module` subclass with learned layers) +- Imports like `onnxruntime`, `timm`, `einops`, `physicsnemo` +- `.pt`, `.onnx`, `.safetensors`, `.mdlus` checkpoint files + +**Physics-based indicators** (analytical / derived calculator): + +- Analytical formulas (e.g., `sqrt(u² + v²)`, Clausius–Clapeyron) +- No trained weights or checkpoint files +- Parameterized by physical constants, pressure levels, or thresholds +- Pure `torch` math operations +- No model architecture classes + +Present the detection result to the user: + +> **Model type detected: [NN-based / Physics-based]** +> +> Evidence: +> +> - [list key indicators found] +> +> This determines branching in Steps 2, 3, and 6. +> NN-based models require dependency management and +> checkpoint loading. Physics-based models skip those +> steps. + +Store the model type flag (`nn` or `physics`) for +conditional branching in Steps 2, 3, and 6. + +### **[CONFIRM — Model Type]** + +Ask the user to confirm the detected model type before +proceeding. + +--- + +## Step 1 — Examine Reference & Propose Dependencies + +### 1a. Analyze the reference code + +Examine the reference inference script/repo to identify: + +- **Python packages** required (e.g., `torch`, `onnxruntime`, `einops`, `timm`, custom packages) +- **Model architecture** (PyTorch, ONNX, JAX, etc.) — NN-based only +- **Input/output tensor shapes** and variable names +- **Spatial resolution** (lat/lon grid dimensions and spacing, or flexible for physics-based) +- **For NN-based**: Checkpoint format (`.pt`, `.onnx`, `.safetensors`, etc.) +- **For physics-based**: Input variable patterns (e.g., `[u{level}, v{level}]`) + and output variable formulas (e.g., `ws = sqrt(u² + v²)`) + +### 1b. Propose pyproject.toml dependency group + +**NN-based models only:** + +Propose a new optional dependency group for `pyproject.toml`. Follow the existing pattern: + +```toml +# In [project.optional-dependencies] section of pyproject.toml +model-name = [ + "package1>=version", + "package2", +] +``` + +The group name should be lowercase-hyphenated (e.g., `precip-afno`, `windgust-afno`, `climatenet`). + +Look at the existing groups in `pyproject.toml` +(lines ~59-257) for reference on naming and version +pinning conventions. + +**Also propose adding the new group to the `all` aggregate** in the appropriate line (dx models line). + +**Physics-based models:** + +State: "No external dependencies needed — physics-based model uses +only `torch` and `numpy` (already core dependencies)." + +### **[CONFIRM — Dependencies]** + +Present to the user: + +1. The proposed dependency group name (NN-based) or + "No dependencies" (physics-based) +2. The list of packages with versions (NN-based only) +3. Ask if the packages and group name look correct + +--- + +## Step 2 — Add Dependencies to pyproject.toml + +**NN-based models:** + +After confirmation, edit `pyproject.toml`: + +1. Add the new optional dependency group in alphabetical order + among the per-model extras +2. Add the group to the `all` aggregate (in the dx models line) + +**Physics-based models:** + +Skip — no external dependencies for physics-based model. +State explicitly: "Step 2 skipped — physics-based model +has no external dependencies." + +--- + +## Step 3 — Create Skeleton Class File + +### 3a. Determine class name and file name + +Based on the model name from the reference, propose: + +- **Class name**: PascalCase (e.g., `PrecipitationAFNO`, `DerivedWS`, `WindgustAFNO`) +- **File name**: lowercase with underscores + (e.g., `precipitation_afno.py`, `derived.py`, `wind_gust.py`) +- **File path**: `earth2studio/models/dx/<filename>.py` + +### 3b. Write skeleton with pseudocode + +Create the file with the full structure but pseudocode +implementations. Every `.py` file in `earth2studio/` +**must** start with this license header: + +```python +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +``` + +Use one of the two skeleton templates depending on the +model type detected in Step 0. + +#### NN-based skeleton (dual inheritance: `torch.nn.Module, AutoModelMixin`) + +```python +from collections import OrderedDict + +import numpy as np +import torch + +from earth2studio.models.auto import AutoModelMixin, Package +from earth2studio.models.batch import batch_coords, batch_func +from earth2studio.models.dx.base import DiagnosticModel +from earth2studio.utils import handshake_coords, handshake_dim +from earth2studio.utils.imports import ( + OptionalDependencyFailure, + check_optional_dependencies, +) +from earth2studio.utils.type import CoordSystem + +# Optional dependency imports (try/except pattern) +try: + import optional_package +except ImportError: + OptionalDependencyFailure("model-name") + optional_package = None + +VARIABLES = [...] # List of input variable names from E2STUDIO_VOCAB + +@check_optional_dependencies() +class ModelName(torch.nn.Module, AutoModelMixin): + """One-line description. + + Extended description of the model, its source, + and any relevant details. + + Parameters + ---------- + core_model : torch.nn.Module + Core neural network model + ...additional params... + + Note + ---- + For more information see: <link to paper/repo> + """ + + # 1. Constructor + def __init__(self, core_model: torch.nn.Module, ...): + super().__init__() + self.core_model = core_model + self.register_buffer( + "device_buffer", torch.empty(0) + ) + # TODO: Register normalization buffers (center, scale, etc.) + + # 2. Input coordinates + def input_coords(self) -> CoordSystem: + """Input coordinate system of diagnostic model. + + Returns + ------- + CoordSystem + Coordinate system dictionary + """ + # TODO: Define input coordinates + pass + + # 3. Output coordinates + @batch_coords() + def output_coords( + self, input_coords: CoordSystem, + ) -> CoordSystem: + """Output coordinate system of diagnostic model. + + Parameters + ---------- + input_coords : CoordSystem + Input coordinate system to transform + + Returns + ------- + CoordSystem + Output coordinate system + + Raises + ------ + ValueError + If input coordinates are invalid + """ + # TODO: Validate and transform coordinates + pass + + # 4. Forward pass + @torch.inference_mode() + @batch_func() + def __call__( + self, + x: torch.Tensor, + coords: CoordSystem, + ) -> tuple[torch.Tensor, CoordSystem]: + """Forward pass of diagnostic model. + + Parameters + ---------- + x : torch.Tensor + Input tensor + coords : CoordSystem + Input coordinate system + + Returns + ------- + tuple[torch.Tensor, CoordSystem] + Output tensor and coordinates + """ + # TODO: Validate, forward, return + pass + + # 5. Private forward computation (optional) + def _forward( + self, x: torch.Tensor, coords: CoordSystem, + ) -> torch.Tensor: + """Internal forward pass. + + Parameters + ---------- + x : torch.Tensor + Input tensor on device + coords : CoordSystem + Input coordinate system + + Returns + ------- + torch.Tensor + Output tensor + """ + # TODO: normalize -> core_model -> denormalize + pass + + # 6. Device management + def to( + self, device: torch.device | str, + ) -> DiagnosticModel: + """Move model to device. + + Parameters + ---------- + device : torch.device | str + Target device + + Returns + ------- + DiagnosticModel + Model on target device + """ + # TODO: Device management + pass + + # 7. Default package location + @classmethod + def load_default_package(cls) -> Package: + """Default pre-trained model package. + + Returns + ------- + Package + Model package + """ + # TODO: Default checkpoint location + pass + + # 8. Load model from package + @classmethod + @check_optional_dependencies() + def load_model( + cls, package: Package, + ) -> DiagnosticModel: + """Load diagnostic model from package. + + Parameters + ---------- + package : Package + Model package with checkpoint files + + Returns + ------- + DiagnosticModel + Loaded model instance + """ + # TODO: Load model from package + pass +``` + +#### Physics-based skeleton (`torch.nn.Module` only) + +```python +from collections import OrderedDict + +import numpy as np +import torch + +from earth2studio.models.batch import batch_coords, batch_func +from earth2studio.models.dx.base import DiagnosticModel +from earth2studio.utils import handshake_coords, handshake_dim +from earth2studio.utils.type import CoordSystem + + +class ModelName(torch.nn.Module): + """One-line description. + + Extended description of the calculation, formula, + and any physical basis. + + Parameters + ---------- + levels : list[int | str], optional + Pressure / height levels to compute for + ...additional params... + + Note + ---- + For more information see: <link to reference> + """ + + # 1. Constructor + def __init__( + self, levels: list[int | str] = [100], + ) -> None: + super().__init__() + self.levels = levels + self.in_variables = [...] # Built from levels + self.out_variables = [...] # Built from levels + + # 2. Input coordinates + def input_coords(self) -> CoordSystem: + """Input coordinate system of diagnostic model. + + Returns + ------- + CoordSystem + Coordinate system dictionary + """ + # TODO: Define input coordinates + pass + + # 3. Output coordinates + @batch_coords() + def output_coords( + self, input_coords: CoordSystem, + ) -> CoordSystem: + """Output coordinate system of diagnostic model. + + Parameters + ---------- + input_coords : CoordSystem + Input coordinate system to transform + + Returns + ------- + CoordSystem + Output coordinate system + + Raises + ------ + ValueError + If input coordinates are invalid + """ + # TODO: Validate and transform coordinates + pass + + # 4. Forward pass + @torch.inference_mode() + @batch_func() + def __call__( + self, + x: torch.Tensor, + coords: CoordSystem, + ) -> tuple[torch.Tensor, CoordSystem]: + """Forward pass of diagnostic model. + + Parameters + ---------- + x : torch.Tensor + Input tensor + coords : CoordSystem + Input coordinate system + + Returns + ------- + tuple[torch.Tensor, CoordSystem] + Output tensor and coordinates + """ + # TODO: Physics computation + pass + + # 5. Device management + def to( + self, device: torch.device | str, + ) -> DiagnosticModel: + """Move model to device. + + Parameters + ---------- + device : torch.device | str + Target device + + Returns + ------- + DiagnosticModel + Model on target device + """ + super().to(device) + return self +``` + +### Canonical method ordering for diagnostic models + +Methods in the class **must** appear in this order: + +1. `__init__` — constructor +2. `input_coords` — input coordinate system +3. `output_coords` — output coordinate system (decorated `@batch_coords()`) +4. `__call__` — forward pass (decorated `@torch.inference_mode()`, `@batch_func()`) +5. `_forward` — private computation method (optional helper) +6. `to` — device management +7. `load_default_package` — classmethod returning default `Package` (NN-based only) +8. `load_model` — classmethod loading model from package (NN-based only) + +### Key differences from prognostic models + +Be aware of these critical differences from the +`create-prognostic-wrapper` skill: + +- **No `PrognosticMixin`** — diagnostic models do NOT + inherit from `PrognosticMixin` (no iterator hooks needed) +- **No `create_iterator`** or `_default_generator` — + diagnostic models do NOT perform time integration +- **Typically no `lead_time`** — most diagnostic models do + not have a `lead_time` dimension. However, some models + that consume forecast output at specific lead times (e.g., + solar radiation, wind gust) may include `lead_time` in + their coordinate system. Only add it if the model + genuinely needs temporal context +- **Match `handshake_dim` indices to coordinate position** — + use the actual position of each dimension in the input + `CoordSystem` OrderedDict. For simple models with + `(batch, variable, lat, lon)` the indices are `1, 2, 3`. + For models with `(batch, time, variable, lat, lon)` use + `2, 3, 4`. You may also use negative indices (`-3, -2, -1`) + if the model must accept flexible leading dimensions. + Check existing models in `earth2studio/models/dx/` for + the predominant convention used in the codebase. + +### **[CONFIRM — Skeleton]** + +Present to the user: + +1. The proposed class name +2. The proposed file name and path +3. The detected model type (NN-based or physics-based) +4. Ask if these are acceptable + +--- + +## Step 4 — Implement Coordinate System + +### 4a. Map variables to E2STUDIO_VOCAB + +Read `earth2studio/lexicon/base.py` and verify every +variable the model uses exists in `E2STUDIO_VOCAB`. +The vocab contains 282 entries including: + +| Category | Examples | +|---|---| +| Surface wind | `u10m`, `v10m`, `ws10m`, `u100m`, `v100m` | +| Surface temp | `t2m`, `d2m`, `sst`, `skt` | +| Humidity | `r2m`, `q2m`, `tcwv` | +| Pressure | `sp`, `msl` | +| Precipitation | `tp`, `lsp`, `cp`, `tp06` | +| Pressure-level | `u50`-`u1000`, `v50`-`v1000`, `z50`-`z1000` | +| Cloud/radiation | `tcc`, `rlut`, `rsut` | + +Pressure levels available: 50, 100, 150, 200, 250, +300, 400, 500, 600, 700, 850, 925, 1000. + +If a variable in the reference model does NOT exist +in `E2STUDIO_VOCAB`, flag it to the user and discuss +whether to: + +- Map it to an existing vocab entry +- Propose adding a new vocab entry (separate step) + +### 4b. Implement input_coords + +**NN-based** (fixed grid): + +```python +def input_coords(self) -> CoordSystem: + """Input coordinate system of diagnostic model. + + Returns + ------- + CoordSystem + Coordinate system dictionary + """ + return OrderedDict({ + "batch": np.empty(0), # MUST be first, MUST be np.empty(0) + "time": np.empty(0), # Dynamic time dimension (optional) + "variable": np.array(VARIABLES), + "lat": np.linspace(90, -90, num_lat, endpoint=...), # From reference + "lon": np.linspace(0, 360, num_lon, endpoint=False), # From reference + }) +``` + +**Physics-based** (flexible grid): + +```python +def input_coords(self) -> CoordSystem: + """Input coordinate system of diagnostic model. + + Returns + ------- + CoordSystem + Coordinate system dictionary + """ + return OrderedDict({ + "batch": np.empty(0), # MUST be first, MUST be np.empty(0) + "variable": np.array(self.in_variables), + "lat": np.empty(0), # Flexible — accepts any grid + "lon": np.empty(0), # Flexible — accepts any grid + }) +``` + +**Rules:** + +- `batch` is always first with `np.empty(0)` +- `time` is `np.empty(0)` (dynamic) — include only if + the model needs time information (e.g., for solar + zenith angle calculations) +- **Typically no `lead_time`** — most diagnostic models do + not include `lead_time`. Add it only if the model needs + temporal context (e.g., solar radiation models that need + time-of-day, or models that consume forecast output at + specific lead times) +- `lat` typically goes from 90 to -90 (north to south) + for NN-based; use `np.empty(0)` for physics-based + (flexible grid) +- `lon` typically goes from 0 to 360 for NN-based; + use `np.empty(0)` for physics-based (flexible grid) + +### 4c. Implement output_coords + +```python +@batch_coords() +def output_coords(self, input_coords: CoordSystem) -> CoordSystem: + """Output coordinate system of diagnostic model. + + Parameters + ---------- + input_coords : CoordSystem + Input coordinates to validate and transform + + Returns + ------- + CoordSystem + Output coordinates with updated variable list + + Raises + ------ + ValueError + If input coordinates are invalid + """ + target_input_coords = self.input_coords() + + # Validate dimensions exist at correct indices + # Use the positional index matching each dim's position in input_coords + # For (batch, variable, lat, lon) → 1, 2, 3 + # For (batch, time, variable, lat, lon) → 2, 3, 4 + # Negative indices (-3, -2, -1) also work for flexible leading dims + handshake_dim(input_coords, "variable", 1) + handshake_dim(input_coords, "lat", 2) + handshake_dim(input_coords, "lon", 3) + + # Validate coordinate values match + handshake_coords(input_coords, target_input_coords, "variable") + handshake_coords(input_coords, target_input_coords, "lat") + handshake_coords(input_coords, target_input_coords, "lon") + + output_coords = input_coords.copy() + output_coords["variable"] = np.array(OUTPUT_VARIABLES) + return output_coords +``` + +**Key points:** + +- Use `@batch_coords()` decorator +- Use `handshake_dim` with indices matching each + dimension's position in the `CoordSystem` OrderedDict. + Use positive indices (e.g., `1, 2, 3`) or negative + indices (e.g., `-3, -2, -1`) — match the convention + of the most similar existing model in the codebase. +- Validate coordinate values with `handshake_coords` +- Output typically changes the `variable` array (different + output variables from input variables) +- Unlike prognostic models, there is typically no + `lead_time` to increment + +### **[CONFIRM — Coordinates]** + +Present to the user: + +1. The input and output variable lists and any mapping + issues with `E2STUDIO_VOCAB` +2. The spatial dimensions (lat/lon grid size and spacing, + or "flexible" for physics-based) +3. Whether `lead_time` is needed (and why — most dx + models omit it) +4. Whether `time` is included (and why, if so) + +--- + +## Step 5 — Implement Forward Pass + +### 5a. Implement `__call__` + +```python +@torch.inference_mode() +@batch_func() +def __call__( + self, + x: torch.Tensor, + coords: CoordSystem, +) -> tuple[torch.Tensor, CoordSystem]: + """Forward pass of diagnostic model. + + Parameters + ---------- + x : torch.Tensor + Input tensor + coords : CoordSystem + Input coordinate system + + Returns + ------- + tuple[torch.Tensor, CoordSystem] + Output tensor and coordinates + """ + # Get output coordinates (this also validates input coords) + output_coords = self.output_coords(coords) + + # Move to device + device = self.device_buffer.device # NN-based + # or: device = next(self.parameters()).device + x = x.to(device) + + # Run forward pass + out = self._forward(x, coords) + + return out, output_coords +``` + +Key implementation notes: + +- The `@batch_func()` decorator handles the batch + dimension — inside `__call__`, `x` has shape + `(batch, ..., variable, lat, lon)` where `...` may + include `time` and/or `lead_time` depending on the + upstream pipeline +- Coordinate validation is handled inside `output_coords()` + — no need to repeat `handshake_dim`/`handshake_coords` + in `__call__` (follow the pattern of existing dx models) +- All tensor operations should happen on GPU when + possible +- **No `create_iterator`**, no hooks, no yielding — + diagnostic models perform a single forward pass only + +### 5b. Implement `_forward` + +**NN-based** (normalize → model → denormalize): + +```python +def _forward( + self, x: torch.Tensor, coords: CoordSystem, +) -> torch.Tensor: + """Internal forward pass.""" + # Normalize input + x = (x - self.center) / self.scale + + # Reshape for core model if needed + # TODO: Reshape from (..., variable, lat, lon) to model format + + # Core model forward + out = self.core_model(x) + + # Reshape output back + # TODO: Reshape from model format to (..., out_variable, lat, lon) + + # Denormalize output (if needed) + out = out * self.output_scale + self.output_center + + return out +``` + +**Physics-based** (direct torch math): + +```python +@torch.inference_mode() +@batch_func() +def __call__( + self, + x: torch.Tensor, + coords: CoordSystem, +) -> tuple[torch.Tensor, CoordSystem]: + """Forward pass of diagnostic.""" + output_coords = self.output_coords(coords) + + # Example: wind speed from u, v components + u = x[..., ::2, :, :] # Every other variable starting from 0 + v = x[..., 1::2, :, :] # Every other variable starting from 1 + out = torch.sqrt(u**2 + v**2) + + return out, output_coords +``` + +For physics-based models, the computation typically +goes directly in `__call__` rather than a separate +`_forward` method. + +### **[CONFIRM — Forward Pass]** + +Show the user the pseudocode for `__call__` +(especially the reshape logic for NN-based, or the +formula for physics-based). Ask: + +1. Does the computation logic look correct for this + model? +2. Are there any special considerations (e.g., multiple + ONNX models, clamping, special tensor layouts)? + +--- + +## Step 6 — Implement Model Loading (NN-based only) + +**Physics-based models:** Skip this step entirely. +State explicitly: "Step 6 skipped — physics-based models +have no checkpoints to load." + +### 6a. Implement load_default_package + +```python +@classmethod +def load_default_package(cls) -> Package: + """Default pre-trained model package on <source>. + + Returns + ------- + Package + Model package + """ + return Package( + "ngc://models/org/model@version", # or hf://, s3://, local path + cache_options={ + "cache_storage": Package.default_cache("model_name"), + "same_names": True, + }, + ) +``` + +### 6b. Implement load_model + +```python +@classmethod +@check_optional_dependencies() +def load_model( + cls, + package: Package, +) -> DiagnosticModel: + """Load diagnostic model from package. + + Parameters + ---------- + package : Package + Model package with checkpoint files + + Returns + ------- + DiagnosticModel + Loaded model instance + """ + # Resolve checkpoint files + checkpoint_path = package.resolve("model.pt") + + # Load model + core_model = torch.load( + checkpoint_path, map_location="cpu", weights_only=False, + ) + core_model.eval() + core_model.requires_grad_(False) + + # Load any additional data files (normalization, masks, etc.) + # center = torch.Tensor(np.load(package.resolve("center.npy"))) + # scale = torch.Tensor(np.load(package.resolve("scale.npy"))) + + return cls(core_model) +``` + +**Key patterns:** + +- Use `package.resolve("filename")` to get cached + file paths +- Load with `map_location="cpu"` then let user call + `.to(device)` +- Set model to `eval()` mode and `requires_grad_(False)` +- Do NOT over-populate `load_model()` API — only + expose essential parameters +- Use `@check_optional_dependencies()` if the model + has optional deps + +### 6c. Implement .to() + +> **Note:** When the wrapper inherits from `torch.nn.Module`, +> `super().to(device)` already handles moving all registered +> parameters, buffers, and sub-modules. A custom `to()` +> override is only needed when there is non-PyTorch state to +> manage (e.g., ONNX Runtime sessions that must be destroyed +> and recreated on a new device, or JAX device placement). +> If `super().to(device)` is sufficient, you can omit the +> override entirely. + +```python +def to(self, device: torch.device | str) -> DiagnosticModel: + """Move model to device. + + Parameters + ---------- + device : torch.device | str + Target device + + Returns + ------- + DiagnosticModel + Model on target device + """ + super().to(device) + # If using ONNX Runtime, destroy and recreate session on new device + # If using PyTorch, super().to(device) handles it + return self +``` + +### **[CONFIRM — Model Loading]** + +Present to the user: + +1. The checkpoint URL/path for `load_default_package` +2. The checkpoint file names and loading logic +3. Whether there are multiple checkpoint files +4. The `.to()` implementation (especially if ONNX or + non-PyTorch backend) + +--- + +## Step 7 — Register the Model + +### 7a. Add to `__init__.py` + +Edit `earth2studio/models/dx/__init__.py`: + +- Add import in alphabetical order: + `from earth2studio.models.dx.<filename> import <ClassName>` +- Add `<ClassName>` to the `__all__` list in alphabetical order + +### 7b. Verify pyproject.toml (NN-based only) + +Confirm the dependency group was added in Step 2 and is included in the `all` aggregate. + +For physics-based models, no pyproject.toml verification is needed. + +--- + +## Step 8 — Verify Style, Documentation, Format & Lint + +Before testing, verify the wrapper passes all code quality checks. + +### 8a. Run formatting + +```bash +make format +``` + +This runs `black` on the codebase. Fix any formatting issues in the new wrapper file and test file. + +### 8b. Run linting + +```bash +make lint +``` + +This runs `ruff` and `mypy`. Common issues to watch +for: + +- Missing type annotations on public functions +- Unused imports +- Import ordering issues +- Type errors from incorrect return types or missing + `CoordSystem` annotations + +Fix all errors before proceeding. + +### 8c. Check license headers + +```bash +make license +``` + +Verify that both the wrapper file +(`earth2studio/models/dx/<filename>.py`) and the test +file (`test/models/dx/test_<filename>.py`) have the +correct SPDX Apache-2.0 license header. + +### 8d. Verify documentation + +Check that: + +- The class docstring follows NumPy-style formatting + with `Parameters`, `Note`, etc. +- All public methods (`__call__`, `input_coords`, + `output_coords`, `to`, and for NN-based: + `load_default_package`, `load_model`) have complete + docstrings with `Parameters`, `Returns`, `Raises` + sections as applicable +- Type hints are present on all public method + signatures + +If any checks fail, fix the issues and re-run until all pass cleanly. + +--- + +## Step 9 — Test Forward Pass with Random Data + +Write and run a quick smoke test script. + +**Note:** Diagnostic models do NOT have `create_iterator` — only test `__call__`. + +```python +import torch +import numpy as np +from earth2studio.models.dx import ModelName + +# Load model (or construct with dummy weights for testing) +model = ModelName(...) # Use dummy/test weights if real ones aren't available +model = model.to("cuda" if torch.cuda.is_available() else "cpu") + +# Get input coords +input_coords = model.input_coords() + +# Create random input tensor +shape = tuple(max(len(v), 1) for v in input_coords.values()) +x = torch.randn(shape) + +# Test __call__ +output, out_coords = model(x, input_coords) +print(f"Input shape: {x.shape}") +print(f"Output shape: {output.shape}") +print(f"Input variables: {input_coords['variable']}") +print(f"Output variables: {out_coords['variable']}") +``` + +Report results to the user. There is no `create_iterator` +to test — diagnostic models perform a single forward pass. + +--- + +## Step 10 — Test Data Fetch with Random Source + +Test that the model's coordinate system works with Earth2Studio's data pipeline: + +```python +import numpy as np +from earth2studio.data import Random, fetch_data +from earth2studio.models.dx import ModelName + +model = ModelName(...) + +# Create time array +time = np.array([np.datetime64("2024-01-01T00:00")]) + +# Fetch data using model's input coords +input_coords = model.input_coords() +input_coords["time"] = time +ds = Random(input_coords) +x, coords = fetch_data(ds, time, input_coords["variable"]) + +print(f"Fetched data shape: {x.shape}") +print(f"Variables: {input_coords['variable']}") +``` + +Report results to the user. This validates the coordinate system is compatible with the data pipeline. + +--- + +## Step 11 — Write Pytest Unit Tests + +Create `test/models/dx/test_<filename>.py` following the existing test patterns. + +### 11a. Test file structure + +```python +# License header (same SPDX Apache-2.0 header as above) + +from collections import OrderedDict + +import numpy as np +import pytest +import torch + +from earth2studio.data import Random, fetch_data +from earth2studio.models.auto import Package +from earth2studio.models.dx import ModelName +from earth2studio.utils import handshake_dim + + +class PhooModelName(torch.nn.Module): + """Dummy model that performs a simple deterministic operation.""" + + def __init__(self): + super().__init__() + + def forward(self, x): + return x[..., :NUM_OUTPUT_VARS, :, :] # Simple slice for testing + + +@pytest.fixture(scope="class") +def test_package(tmp_path_factory): + """Create a dummy model package for testing.""" + tmp_path = tmp_path_factory.mktemp("data") + # Export dummy model to checkpoint format + model = PhooModelName() + torch.save(model, tmp_path / "model.pt") + # Save any additional files (normalization, etc.) + # np.save(tmp_path / "center.npy", np.zeros(NUM_VARS)) + return Package(str(tmp_path)) + + +class TestModelNameMock: + @pytest.mark.parametrize( + "time", + [ + np.array([ + np.datetime64("1999-10-11T12:00"), + np.datetime64("2001-06-04T00:00"), + ]), + ], + ) + @pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda:0", + marks=pytest.mark.skipif( + not torch.cuda.is_available(), reason="No GPU" + ), + ), + ], + ) + def test_model_call(self, test_package, time, device): + """Test single forward pass.""" + # NN-based: use load_model + model = ModelName.load_model(test_package) + # Physics-based: construct directly instead + # model = ModelName(levels=[1000]) + model = model.to(device) + + # Fetch input data + dc = model.input_coords() + dc["time"] = time + ds = Random(dc) + x, coords = fetch_data(ds, time, dc["variable"], device=device) + + # Run forward + out, out_coords = model(x, coords) + + # Validate output + assert isinstance(out_coords, OrderedDict) + handshake_dim(out_coords, "variable", 1) # Adjust index to match model's coord ordering + # Additional model-specific assertions + # e.g., assert output variable count is correct + + @pytest.mark.parametrize( + "coords", + [ + OrderedDict({ + "batch": np.empty(0), + "variable": np.array(["wrong_var"]), + "lat": np.linspace(90, -90, 10), + "lon": np.linspace(0, 360, 20), + }), + ], + ) + def test_model_exceptions(self, test_package, coords): + """Test model raises on invalid coordinates.""" + # NN-based: use load_model + model = ModelName.load_model(test_package) + # Physics-based: construct directly instead + # model = ModelName(levels=[1000]) + x = torch.randn( + 1, + len(coords["variable"]), + len(coords["lat"]), + len(coords["lon"]), + ) + with pytest.raises((KeyError, ValueError)): + model(x, coords) +``` + +**Note:** There is NO `test_model_iter` — diagnostic +models do not have `create_iterator`. + +**Physics-based models — add `test_physics_correctness`:** + +```python +def test_physics_correctness(self): + """Verify physics formula against known values.""" + model = ModelName(levels=[1000]) + + # Create known input values + # e.g., for wind speed: u=3, v=4 → ws=5 + input_coords = model.input_coords() + x = torch.zeros(1, len(input_coords["variable"]), 5, 5) + x[:, 0, :, :] = 3.0 # u component + x[:, 1, :, :] = 4.0 # v component + + out, _ = model(x, input_coords) + + expected = torch.full_like(out, 5.0) # sqrt(3² + 4²) = 5 + assert torch.allclose(out, expected, rtol=1e-5) +``` + +**NN-based models — add `@pytest.mark.package` integration test:** + +```python +@pytest.mark.package +def test_model_package(): + """Integration test with real model weights.""" + model = ModelName.from_pretrained() + input_coords = model.input_coords() + time = np.array([np.datetime64("2024-01-01T00:00")]) + input_coords["time"] = time + ds = Random(input_coords) + x, coords = fetch_data(ds, time, input_coords["variable"]) + model = model.to("cuda" if torch.cuda.is_available() else "cpu") + x = x.to(model.device_buffer.device) + out, out_coords = model(x, coords) + # Validate output shape and variables + assert out_coords["variable"].shape[0] > 0 +``` + +Adapt the dummy model (`PhooModelName`) to match the +actual model's input/output interface so the wrapper's +reshaping logic is exercised. + +--- + +## Step 11b — Run Tests + +### 11b-i. Run mock tests (no package flag) + +First, run the unit tests that use mocked / dummy models. +These do **not** require downloading real checkpoints and +should run quickly on any machine: + +```bash +uv run python -m pytest test/models/dx/test_<filename>.py \ + -m "not package" -v +``` + +All mock tests must pass before proceeding. Fix any +failures and re-run until green. + +### 11b-ii. Run the package integration test (NN-based only) + +Once the mock tests pass, run the `@pytest.mark.package` +test which exercises `from_pretrained()` with real model +weights: + +```bash +uv run python -m pytest test/models/dx/test_<filename>.py \ + -m "package" -v +``` + +### **[CONFIRM — Package Test]** + +Before executing the package test, warn the user: + +> The package / integration test will: +> +> - **Download the model checkpoint** (may be several GB) +> to the local cache +> - **Require GPU compute** for models that need CUDA +> (the test will skip on CPU-only machines if +> `torch.cuda.is_available()` is `False`) +> - Take significantly longer than the mock tests +> +> Do you want to proceed with the package test? + +Only run the package test after the user confirms. Report +the results back to the user. If the package test fails, +debug and fix the wrapper or test, then re-run. + +For physics-based models, there is no package test to run. +State: "No package test — physics-based model has no +checkpoints." + +--- + +## Step 12 — Provide Side-by-Side Comparison Scripts + +Present two scripts to the user: + +### Reference script (without Earth2Studio) + +Reconstruct a minimal inference script based on the original reference code: + +**NN-based:** + +```python +# Reference inference (no Earth2Studio) +import torch +# ... original model imports ... + +# Load model +model = OriginalModel.from_pretrained("path/to/checkpoint") +model.eval().cuda() + +# Prepare input +input_data = ... # Load/prepare input data per original repo instructions + +# Run inference +with torch.no_grad(): + output = model(input_data) + +print(f"Output shape: {output.shape}") +``` + +**Physics-based:** + +```python +# Reference calculation (no Earth2Studio) +import torch + +# Hand-calculated / reference values +u = torch.tensor([3.0, 0.0, 1.0]) +v = torch.tensor([4.0, 5.0, 0.0]) + +# Direct formula +ws = torch.sqrt(u**2 + v**2) + +print(f"Wind speed: {ws}") # Expected: [5.0, 5.0, 1.0] +``` + +### Earth2Studio equivalent + +**NN-based:** + +```python +# Earth2Studio inference +import torch +import numpy as np +from earth2studio.models.dx import ModelName +from earth2studio.data import Random, fetch_data + +# Load model +model = ModelName.from_pretrained() +model = model.to("cuda") + +# Prepare input via Earth2Studio data pipeline +time = np.array([np.datetime64("2024-01-01T00:00")]) +input_coords = model.input_coords() +input_coords["time"] = time +ds = Random(input_coords) # Replace with real data source +x, coords = fetch_data(ds, time, input_coords["variable"], device="cuda") + +# Single forward pass (no iterator — this is a diagnostic model) +with torch.no_grad(): + output, out_coords = model(x, coords) + +print(f"Output shape: {output.shape}") +print(f"Output variables: {out_coords['variable']}") +``` + +**Physics-based:** + +```python +# Earth2Studio inference +import torch +import numpy as np +from earth2studio.models.dx import ModelName + +# Construct model (no checkpoint needed) +model = ModelName(levels=[1000]) + +# Prepare input +input_coords = model.input_coords() +x = torch.zeros(1, len(input_coords["variable"]), 5, 5) +x[:, 0, :, :] = 3.0 # u component +x[:, 1, :, :] = 4.0 # v component + +# Single forward pass +output, out_coords = model(x, input_coords) + +print(f"Output shape: {output.shape}") +print(f"Output variables: {out_coords['variable']}") +print(f"Output values (should be 5.0): {output.mean():.1f}") +``` + +### **[CONFIRM — Comparison Scripts]** + +Ask the user to compare the two scripts and verify the +Earth2Studio version is functionally equivalent to the +reference. + +For physics-based models, also ask the user to verify +the output matches hand-calculated expected values. + +--- + +## Reminders + +- **DO NOT** make a general base class with intent to reuse the wrapper across models +- **DO NOT** over-populate the `load_model()` API — only expose essential parameters +- **DO NOT** add `lead_time` dimension unless the model genuinely needs temporal context + (e.g., solar radiation, wind gust models that depend on forecast lead time) +- **DO NOT** add `create_iterator` or `_default_generator` — diagnostic models are single-pass +- **DO NOT** inherit from `PrognosticMixin` — diagnostic models do not need iterator hooks +- **DO** use `handshake_dim` indices matching each dimension's position in the + `CoordSystem` OrderedDict — check existing dx models for the predominant convention +- **DO** add the model to `docs/modules/models.rst` + in the `earth2studio.models.dx` section + (alphabetical order) +- **DO** use `loguru.logger` for logging, never `print()`, inside `earth2studio/` +- **DO** ensure all public functions have full type hints +- **DO** run formatting (`make format`) and linting (`make lint`) before finalizing +- **DO** use `@torch.inference_mode()` on `__call__` for inference-only models +- **DO** set `eval()` and `requires_grad_(False)` on loaded NN models +- **DO** use `@batch_func()` on `__call__` and `@batch_coords()` on `output_coords` +- **DO** validate coordinates with `handshake_coords()` and `handshake_dim()` +- **DO** move tensors to device before operations: `x.to(device)` +- **DO** use `uv run python` for all Python commands (never bare `python`) diff --git a/.claude/skills/validate-diagnostic-wrapper/SKILL.md b/.claude/skills/validate-diagnostic-wrapper/SKILL.md new file mode 100644 index 000000000..1f0a6ade3 --- /dev/null +++ b/.claude/skills/validate-diagnostic-wrapper/SKILL.md @@ -0,0 +1,831 @@ +--- +name: validate-diagnostic-wrapper +description: Validate a newly created Earth2Studio diagnostic model wrapper by running tests, performing reference comparison, generating sanity-check outputs, and opening a PR with automated code review. Use this skill after completing diagnostic model implementation (create-diagnostic-wrapper skill Steps 0-12). +argument-hint: Name of the diagnostic model class and test file (optional — will be inferred from recent changes if not provided) +--- + +# Validate Diagnostic Model Wrapper + +Validate a newly created Earth2Studio diagnostic model (dx) wrapper by +running tests, performing reference comparison, generating sanity-check +outputs, and opening a PR with automated code review. This skill picks +up after the `create-diagnostic-wrapper` skill completes implementation +(Steps 0-12). + +> **Python Environment:** This project uses **uv** for dependency +> management. Always use the local `.venv` virtual environment +> (`source .venv/bin/activate` or prefix with `uv run python`) for all +> Python commands — installing packages, running tests, executing +> scripts, etc. Use `uv add` / `uv pip install` / `uv lock` instead of +> `pip install`. + +Each confirmation gate marked by: + +```markdown +### **[CONFIRM — <Title>]** +``` + +requires **explicit user approval** before proceeding. + +--- + +## Step 1 — Run Tests + +### 1a. Run the new test file + +```bash +uv run python -m pytest test/models/dx/test_<filename>.py -v --timeout=60 +``` + +All tests must pass. Fix failures and re-run until green. + +### 1b. Run coverage report with `--slow` tests + +Run the new test file **with coverage** and the `--slow` flag to +include integration tests. The new diagnostic model file must achieve +**at least 90% line coverage**: + +```bash +uv run python -m pytest test/models/dx/test_<filename>.py -v \ + --slow --timeout=300 \ + --cov=earth2studio/models/dx/<filename> \ + --cov-report=term-missing \ + --cov-fail-under=90 +``` + +- `--slow` enables integration tests (marked `@pytest.mark.slow`) +- `--cov=earth2studio/models/dx/<filename>` scopes coverage to the + new model module only +- `--cov-report=term-missing` shows which lines are not covered +- `--cov-fail-under=90` fails the run if coverage is below 90% + +If coverage is below 90%, add additional tests or mock tests to +cover the missing lines. Common coverage gaps for dx models: + +- Error handling in `output_coords` (wrong variable names, wrong dims) +- Device management paths (CPU vs CUDA) +- Edge cases in physics calculations (zero inputs, extreme values) +- `load_model` and `load_default_package` (NN-based models, needs mock) + +Re-run until coverage is at or above 90%. + +### 1c. Run the full model test suite (optional but recommended) + +```bash +make pytest TOX_ENV=test-models +``` + +Confirm no regressions in existing model tests. + +--- + +## Step 2 — Reference Comparison & Sanity-Check + +This step validates the diagnostic model wrapper produces correct +output by comparing against the original reference implementation and +generating visual sanity-check plots. + +### 2a. Create reference comparison script + +Create a **standalone Python script** in the repo root. This is for +validation only and should **NOT** be committed to the repo. + +The script loads the reference model and the E2S wrapper side by side, +runs both on identical input (same random seed or real data), and +compares outputs with tolerance: + +```python +"""Reference comparison for <ModelName> diagnostic model. + +Compares the Earth2Studio wrapper output against the original reference +implementation to verify numerical agreement. + +This script is for validation only — do NOT commit to the repo. +""" +import torch +import numpy as np + +# --- Reference model --- +# TODO: Load original model per reference repo instructions +# Uncomment and adapt the following lines: +# ref_model = ... +# ref_input = ... +# ref_output = ref_model(ref_input) +raise NotImplementedError( + "Fill in the reference model code above, then remove this line." +) + +# --- Earth2Studio wrapper --- +from earth2studio.models.dx import ModelName + +model = ModelName(...) # or ModelName.load_model(package) +device = "cuda" if torch.cuda.is_available() else "cpu" +model = model.to(device) + +input_coords = model.input_coords() +# Construct input tensor matching the reference input +# Use the same random seed or identical real data for both +shape = tuple(max(len(v), 1) for v in input_coords.values()) +torch.manual_seed(42) +x = torch.randn(shape, device=device) + +e2s_output, out_coords = model(x, input_coords) + +# --- Compare outputs --- +# Ensure both tensors are on the same device and dtype +ref_output = ref_output.to(e2s_output.device) + +max_abs_diff = (ref_output - e2s_output).abs().max().item() +max_rel_diff = ( + (ref_output - e2s_output).abs() / (ref_output.abs() + 1e-8) +).max().item() +correlation = torch.corrcoef( + torch.stack([ref_output.flatten(), e2s_output.flatten()]) +)[0, 1].item() + +print(f"Max absolute difference: {max_abs_diff:.2e}") +print(f"Max relative difference: {max_rel_diff:.2e}") +print(f"Correlation: {correlation:.8f}") + +assert torch.allclose(ref_output, e2s_output, rtol=1e-4, atol=1e-5), \ + f"Output mismatch! Max abs diff: {max_abs_diff:.2e}" + +print("PASS: Reference comparison successful.") +``` + +**For physics-based models**, compare against hand-calculated expected +values instead of a reference model: + +```python +# Known test case: e.g., u=3, v=4 -> wind_speed=5 +u_input = torch.tensor([3.0]) +v_input = torch.tensor([4.0]) +expected_ws = torch.tensor([5.0]) + +# Run through E2S wrapper +# ... +assert torch.allclose(e2s_output, expected_ws, atol=1e-6), \ + f"Physics check failed: expected {expected_ws}, got {e2s_output}" +``` + +### 2b. Summarize model capabilities to user + +Before generating sanity-check plots, **present a summary table** to +the user covering the model's capabilities: + +> **Model Summary for `<ClassName>`:** +> +> | Property | Value | +> |---|---| +> | **Input variables** | `var1`, `var2`, ... | +> | **Output variables** | `out1`, `out2`, ... | +> | **Spatial resolution** | X.XX deg x Y.YY deg (NxM) / Flexible | +> | **Checkpoint size** | XX MB / N/A (physics-based) | +> | **Checkpoint source** | NGC / HuggingFace / N/A | +> | **Inference time** | ~XX ms per forward pass (on GPU/CPU) | + +This summary helps the user verify the wrapper matches their +expectations for the model. + +### 2c. Generate sanity-check plot script + +Create a **standalone Python script** in the repo root. This is for +PR reviewer reference only and should **NOT** be committed to the +repo. + +Choose the appropriate template based on the model's output type: + +#### Spatial gridded outputs (e.g., precipitation, solar radiation) + +```python +"""Sanity-check plot for <ModelName> diagnostic model. + +This script is for PR review only — do NOT commit to the repo. +Run it to produce a quick visualization confirming the model works. +""" +import matplotlib.pyplot as plt +import numpy as np +import torch + +from earth2studio.models.dx import ModelName + +# Load model +model = ModelName(...) # or ModelName.load_model(package) +device = "cuda" if torch.cuda.is_available() else "cpu" +model = model.to(device) + +# Create or fetch input data +input_coords = model.input_coords() +shape = tuple(max(len(v), 1) for v in input_coords.values()) +x = torch.randn(shape, device=device) + +# Run forward pass +output, out_coords = model(x, input_coords) +output_np = output.cpu().numpy() + +# Plot contourf for each output variable +variables = list(out_coords["variable"]) +n_vars = len(variables) +fig, axes = plt.subplots(1, n_vars, figsize=(6 * n_vars, 5)) +if n_vars == 1: + axes = [axes] + +for ax, i_var in zip(axes, range(n_vars)): + var = variables[i_var] + data_2d = output_np[0, 0, i_var, :, :] # batch=0, time=0 + im = ax.contourf(data_2d, cmap="turbo", levels=20) + ax.set_title(f"{var}") + plt.colorbar(im, ax=ax, shrink=0.8) + +plt.suptitle(f"<ModelName> diagnostic output", y=1.02) +plt.tight_layout() +plt.savefig("sanity_check_<model_name>.png", dpi=150, bbox_inches="tight") +print("Saved: sanity_check_<model_name>.png") +``` + +#### Physics-based outputs (e.g., derived quantities like wind speed) + +```python +"""Sanity-check for <ModelName> physics-based diagnostic. + +This script is for PR review only — do NOT commit to the repo. +Validates exact physics results against known test cases. +""" +import matplotlib.pyplot as plt +import numpy as np +import torch + +from earth2studio.models.dx import ModelName + +model = ModelName(...) + +# Known test cases +# e.g., u=3, v=4 -> wind_speed=5 +test_inputs = { + "u": [0.0, 3.0, -5.0, 10.0], + "v": [0.0, 4.0, 12.0, 0.0], +} +expected_outputs = { + "ws": [0.0, 5.0, 13.0, 10.0], +} + +# Construct input tensor from test cases and run model +input_coords = model.input_coords() +n_cases = len(test_inputs["u"]) +n_vars = len(input_coords["variable"]) +# Build tensor: shape (1, n_vars, n_cases, 1) — batch=1, spatial=n_cases x 1 +x = torch.zeros(1, n_vars, n_cases, 1) +# Map test inputs to the correct variable channels +# Adapt these indices to match model.input_coords()["variable"] +var_list = list(input_coords["variable"]) +x[:, var_list.index("u1000"), :, 0] = torch.tensor(test_inputs["u"]) +x[:, var_list.index("v1000"), :, 0] = torch.tensor(test_inputs["v"]) + +# Update coords to match tensor shape +test_coords = input_coords.copy() +test_coords["lat"] = np.arange(n_cases, dtype=np.float64) +test_coords["lon"] = np.array([0.0]) + +output, out_coords = model(x, test_coords) + +# Validate physics +for key, expected in expected_outputs.items(): + actual = output[..., :, 0, 0].cpu().numpy().flatten() + expected_arr = np.array(expected) + print(f"{key}: expected={expected_arr}, actual={actual}") + assert np.allclose(actual, expected_arr, atol=1e-6), \ + f"Physics validation failed for {key}" + +print("PASS: All physics test cases validated.") + +# Side-by-side plot: input vs output +fig, axes = plt.subplots(1, 2, figsize=(12, 5)) +axes[0].bar(range(len(test_inputs["u"])), test_inputs["u"], label="u") +axes[0].bar(range(len(test_inputs["v"])), test_inputs["v"], + alpha=0.7, label="v") +axes[0].set_title("Input components") +axes[0].legend() + +axes[1].bar(range(len(expected_outputs["ws"])), expected_outputs["ws"], + label="Expected", alpha=0.7) +axes[1].bar(range(len(expected_outputs["ws"])), + output[..., :, 0, 0].cpu().numpy().flatten(), + label="E2S output", alpha=0.5) +axes[1].set_title("Output: expected vs actual") +axes[1].legend() + +plt.suptitle("<ModelName> — physics validation") +plt.tight_layout() +plt.savefig("sanity_check_<model_name>.png", dpi=150, bbox_inches="tight") +print("Saved: sanity_check_<model_name>.png") +``` + +#### Scalar/classification outputs (e.g., TC tracking, severity index) + +```python +"""Sanity-check for <ModelName> scalar/classification diagnostic. + +This script is for PR review only — do NOT commit to the repo. +Generates histogram and summary statistics of output values. +""" +import matplotlib.pyplot as plt +import numpy as np +import torch + +from earth2studio.models.dx import ModelName + +# Load model and run inference +model = ModelName(...) +device = "cuda" if torch.cuda.is_available() else "cpu" +model = model.to(device) + +input_coords = model.input_coords() +shape = tuple(max(len(v), 1) for v in input_coords.values()) +x = torch.randn(shape, device=device) + +output, out_coords = model(x, input_coords) +output_np = output.cpu().numpy().flatten() + +# Summary statistics +print(f"Output shape: {output.shape}") +print(f"Min: {output_np.min():.4f}") +print(f"Max: {output_np.max():.4f}") +print(f"Mean: {output_np.mean():.4f}") +print(f"Std: {output_np.std():.4f}") + +# Histogram of output values +fig, ax = plt.subplots(figsize=(8, 5)) +ax.hist(output_np, bins=50, edgecolor="black", alpha=0.7) +ax.set_xlabel("Output value") +ax.set_ylabel("Count") +ax.set_title(f"<ModelName> output distribution (n={len(output_np)})") +ax.axvline(output_np.mean(), color="red", linestyle="--", label=f"Mean: {output_np.mean():.3f}") +ax.legend() + +plt.tight_layout() +plt.savefig("sanity_check_<model_name>.png", dpi=150, bbox_inches="tight") +print("Saved: sanity_check_<model_name>.png") +``` + +### 2d. Run comparison and sanity-check scripts + +Execute both scripts: + +```bash +uv run python reference_comparison_<model_name>.py +uv run python sanity_check_<model_name>.py +``` + +Verify that: + +- The reference comparison passes (all assertions hold) +- The sanity-check script runs without errors +- Output PNGs are generated +- Metrics are printed (max abs diff, max rel diff, correlation) + +### 2e. **[CONFIRM — Sanity-Check & Comparison]** + +**You MUST ask the user to visually inspect the generated plot(s) +before proceeding.** Do not skip this step even if the scripts ran +without errors — a successful run does not guarantee the plots are +correct (e.g., empty axes, wrong colorbar range, garbled data). + +Tell the user the absolute path to the generated image file(s) and +the reference comparison metrics, then ask them to inspect: + +> The reference comparison and sanity-check scripts ran successfully. +> +> **Reference comparison metrics:** +> +> - Max absolute difference: `<value>` +> - Max relative difference: `<value>` +> - Correlation: `<value>` +> +> **Sanity-check plot saved to:** +> `/absolute/path/to/sanity_check_<model_name>.png` +> +> **Please open this image and confirm it looks correct.** Check: +> +> 1. Data is visible on the axes (not blank/empty) +> 2. Values are in physically reasonable ranges +> 3. No obvious artifacts (all-NaN regions, garbled values) +> 4. For spatial outputs: geographic patterns look plausible +> 5. For physics outputs: results match expected analytical values +> +> Does the plot look correct and do the reference comparison metrics +> look acceptable? + +**Do not proceed to Step 3 until the user explicitly confirms.** If +the user reports problems, debug and fix the issue, re-run the +scripts, and ask the user to inspect again. + +--- + +## Step 3 — Branch, Commit & Open PR + +### **[CONFIRM — Ready to Submit]** + +Before proceeding, confirm with the user: + +> All implementation and validation steps are complete: +> +> - Diagnostic model class implemented with correct method ordering +> - Coordinate system with proper `handshake_dim` indices +> - Forward pass implemented (NN-based or physics-based) +> - Model loading implemented (NN-based) or skipped (physics-based) +> - Registered in `earth2studio/models/dx/__init__.py` +> - Documentation added to `docs/modules/models.rst` +> - Reference URLs included in class docstrings +> - CHANGELOG.md updated +> - Format, lint, and license checks pass +> - Unit tests written and passing with >= 90% coverage +> - Dependencies in pyproject.toml confirmed (NN-based) +> - Reference comparison passes with acceptable tolerance +> - Sanity-check plots generated and confirmed by user +> +> Ready to create a branch, commit, and prepare a PR? + +### 3a. Create branch and commit + +```bash +git checkout -b feat/diagnostic-model-<name> +git add earth2studio/models/dx/<filename>.py \ + earth2studio/models/dx/__init__.py \ + test/models/dx/test_<filename>.py \ + pyproject.toml \ + CHANGELOG.md \ + docs/modules/models.rst +git commit -m "feat: add <ClassName> diagnostic model + +Add <ClassName> diagnostic model for <brief description>. +Includes unit tests and documentation." +``` + +Do **NOT** add the sanity-check scripts, comparison scripts, or +their output images. + +### 3b. Identify the fork remote and push branch + +The working repository is typically a **fork** of +`NVIDIA/earth2studio`. Before pushing, confirm which git remote +points to the user's fork: + +```bash +git remote -v +``` + +Ask the user: + +> Which git remote is your fork of `NVIDIA/earth2studio`? +> (Usually `origin` — e.g., `git@github.com:<user>/earth2studio.git`) + +Then push the feature branch to the **fork** remote: + +```bash +git push -u <fork-remote> feat/diagnostic-model-<name> +``` + +### 3c. Open Pull Request (fork -> NVIDIA/earth2studio) + +> **Important:** PRs must be opened **from the fork** to the +> **upstream source repository** `NVIDIA/earth2studio`. The branch +> lives on the fork; the PR targets `main` on the upstream repo. + +Use `gh pr create` with explicit `--repo` and `--head` flags: + +```bash +gh pr create \ + --repo NVIDIA/earth2studio \ + --base main \ + --head <fork-owner>:feat/diagnostic-model-<name> \ + --title "feat: add <ClassName> diagnostic model" \ + --body "..." +``` + +Where `<fork-owner>` is the GitHub username that owns the fork. + +The PR body should follow this diagnostic-model-specific template: + +````markdown +## Description + +Add `<ClassName>` diagnostic model for <brief description>. + +Closes #<issue_number> (if applicable) + +### Model details + +| Property | Value | +|---|---| +| **Model type** | NN-based / Physics-based | +| **Architecture** | PyTorch / ONNX / Analytical | +| **Input variables** | <list> | +| **Output variables** | <list> | +| **Spatial resolution** | X° x Y° (NxM) / Flexible | +| **Checkpoint source** | NGC / HuggingFace / N/A | +| **Reference** | <paper/repo URL> | + +### Dependencies added + +| Package | Version | License | License URL | Reason | +|---|---|---|---|---| +| `<package>` | `>=X.Y` | <License> | [link](<URL>) | <reason> | + +*(or "No new dependencies — physics-based model")* + +### Reference comparison + +- Max absolute difference: <value> +- Max relative difference: <value> +- Correlation: <value> + +### Validation + +See sanity-check plots in PR comments below. + +## Checklist + +- [x] I am familiar with the [Contributing Guidelines][contrib]. +- [x] New or existing tests cover these changes. +- [x] The documentation is up to date with these changes. +- [x] The [CHANGELOG.md][changelog] is up to date with these changes. +- [ ] An [issue][issues] is linked to this pull request. +- [ ] Assess and address Greptile feedback (AI code review bot). + +[contrib]: https://github.com/NVIDIA/earth2studio/blob/main/CONTRIBUTING.md +[changelog]: https://github.com/NVIDIA/earth2studio/blob/main/CHANGELOG.md +[issues]: https://github.com/NVIDIA/earth2studio/issues +```` + +### 3d. Post sanity-check as PR comment + +After the PR is created, post the sanity-check visualization as a +separate **PR comment** so it is immediately visible to reviewers. + +#### Image upload limitation + +**GitHub has no CLI or REST API for uploading images to PR comments.** +The only way to embed an image is via the browser's drag-and-drop +editor or by referencing an already-hosted URL. + +**Practical workflow:** + +1. Write the comment body to a temp file (avoids shell quoting issues + with heredocs containing backticks and markdown). +2. Post the comment **without** the image — include the validation + table, reference comparison metrics, the full sanity-check script, + and a placeholder line. +3. Tell the user to drag the image into the browser editor. + +```bash +# 1. Write body to a temp file (use your editor tool, not heredoc) + +# 2. Post the comment +gh api -X POST repos/NVIDIA/earth2studio/issues/<PR_NUMBER>/comments \ + -F "body=@/tmp/pr_comment_body.md" \ + --jq '.html_url' +``` + +Do **not** waste time trying `curl` uploads, GraphQL file mutations, +or the `uploads.github.com` asset endpoint — they do not work for +issue/PR comment images. + +#### Comment content template + +```markdown +## Sanity-Check Validation + +**Model:** `<ClassName>` — <brief description> +**Type:** NN-based / Physics-based +**Test environment:** <GPU model or CPU> + +### Reference Comparison + +| Metric | Value | +|--------|-------| +| Max absolute difference | <value> | +| Max relative difference | <value> | +| Correlation | <value> | + +### Model Summary + +| Property | Value | +|----------|-------| +| Input variables | <list or count> | +| Output variables | <list or count> | +| Output shape | <shape> | +| Spatial resolution | X° x Y° / Flexible | +| Inference time | ~XX ms | + +**Key findings:** +- <bullet summarizing numerical agreement with reference> +- <bullet on output quality / physical reasonableness> +- <bullet on performance or notable behavior> + +> **TODO:** Attach sanity-check image by editing this comment in +> the browser. + +<details> +<summary>Sanity-check script (click to expand)</summary> + +```python +PASTE THE FULL WORKING SCRIPT HERE — not a truncated excerpt. +The script must be copy-pasteable and produce the plot end-to-end. +``` + +</details> +``` + +**Important:** Always paste the **complete, runnable** script — not +a shortened version. Reviewers should be able to reproduce the plot +by copying the script directly. + +#### Finalize + +After posting, inform the user of: + +1. The comment URL +2. The local path to the image file for manual attachment +3. Instructions: *"Edit the comment in your browser and drag the + image file into the editor to embed it."* + +> **Note:** The sanity-check image and script are for PR review +> purposes only — they must NOT be committed to the repository. + +--- + +## Step 4 — Automated Code Review (Greptile) + +After the PR is created and pushed, an automated code review from +**greptile-apps** (Greptile) will be posted as PR review comments. +Wait for this review, then process the feedback. + +### 4a. Wait for Greptile review + +Poll for review comments from `greptile-apps[bot]` every 30 seconds +for up to **5 minutes**. Time out gracefully if no review arrives: + +```bash +# Poll loop — check every 30s, timeout after 5 minutes (10 attempts) +for i in $(seq 1 10); do + REVIEW_ID=$(gh api repos/NVIDIA/earth2studio/pulls/<PR_NUMBER>/reviews \ + --jq '.[] | select(.user.login == "greptile-apps[bot]") | .id' 2>/dev/null) + if [ -n "$REVIEW_ID" ]; then + echo "Greptile review found: $REVIEW_ID" + break + fi + echo "Attempt $i/10 — no review yet, waiting 30s..." + sleep 30 +done +``` + +If no review after 5 minutes, inform the user: + +> Greptile hasn't posted a review after 5 minutes. This can happen if +> the review bot is busy or the PR hasn't triggered it. You can: +> +> 1. Ask me to check again later +> 2. Skip this step and proceed without automated review +> 3. Manually request a review from Greptile on the PR page + +### 4b. Pull and parse review comments + +Once the review is posted, fetch all comments: + +```bash +# Get all review comments on the PR +gh api repos/NVIDIA/earth2studio/pulls/<PR_NUMBER>/comments \ + --jq '.[] | select(.user.login == "greptile-apps[bot]") | + {path: .path, line: .diff_hunk, body: .body}' +``` + +Also fetch the top-level review body: + +```bash +gh api repos/NVIDIA/earth2studio/pulls/<PR_NUMBER>/reviews \ + --jq '.[] | select(.user.login == "greptile-apps[bot]") | .body' +``` + +### 4c. Categorize and present to user + +Parse each comment and categorize it: + +| Category | Description | Default action | +|---|---|---| +| **Bug / correctness** | Logic errors, wrong behavior | Fix | +| **Style / convention** | Naming, formatting, patterns | Fix if valid | +| **Performance** | Inefficiency, resource waste | Evaluate | +| **Documentation** | Missing/wrong docs, docstrings | Fix | +| **Suggestion** | Alternative approach, nice-to-have | User decides | +| **False positive** | Incorrect or irrelevant feedback | Dismiss | + +### **[CONFIRM — Review Triage]** + +Present each comment to the user in a summary table: + +```markdown +| # | File | Line | Category | Summary | Proposed Action | +|---|------|------|----------|---------|-----------------| +| 1 | <model>.py | 142 | Bug | Missing null check | Fix: add guard | +| 2 | <model>.py | 305 | Style | Use f-string | Fix: convert | +| 3 | <model>.py | 45 | Suggestion | Add type alias | Skip: not needed | +| ... | ... | ... | ... | ... | ... | +``` + +For each comment, briefly explain: + +- What Greptile flagged +- Whether you agree or disagree (with reasoning) +- Your proposed fix (or why to skip) + +Ask the user to confirm which comments to address. The user may: + +- Accept all proposed fixes +- Select specific fixes +- Override your recommendation on any comment +- Add their own fixes + +### 4d. Implement fixes + +For each accepted fix: + +1. Make the code change +2. Run `make format && make lint` after all fixes +3. Run the relevant tests: + + ```bash + uv run python -m pytest test/models/dx/test_<filename>.py -v --timeout=60 + ``` + +4. Commit with a message like: + + ```bash + git commit -m "fix: address code review feedback (Greptile)" + ``` + +### 4e. Respond to review comments + +For each Greptile comment, post a reply on the PR: + +**For fixed comments:** + +```bash +gh api repos/NVIDIA/earth2studio/pulls/<PR_NUMBER>/comments/<COMMENT_ID>/replies \ + -f body="Fixed in <commit_sha>. <brief description of fix>" +``` + +**For dismissed comments (false positives / won't fix):** + +```bash +gh api repos/NVIDIA/earth2studio/pulls/<PR_NUMBER>/comments/<COMMENT_ID>/replies \ + -f body="Won't fix — <brief justification>" +``` + +### 4f. Push and resolve + +```bash +git push <fork-remote> feat/diagnostic-model-<name> +``` + +After pushing, resolve all addressed review threads if possible. + +Inform the user of the final state: + +- How many comments were fixed +- How many were dismissed (with reasons) +- Any remaining open threads + +--- + +## Reminders + +- **DO** use the repo's local `uv` `.venv` to run Python with + `uv run python` +- **DO NOT** commit sanity-check/comparison scripts or images to + the repo +- **DO** use `loguru.logger` for logging, never `print()`, inside + `earth2studio/` +- **DO** ensure all public functions have full type hints (mypy-clean) +- **DO** maintain alphabetical order in `__init__.py` exports, + RST file entries, and CHANGELOG entries +- **DO** follow the canonical DX method ordering: + `__init__`, `input_coords`, `output_coords`, `__call__`, + `_forward`, `to`, `load_default_package`, `load_model` +- **DO** use `handshake_dim` indices matching each dimension's position in the + `CoordSystem` OrderedDict — check existing dx models for the predominant convention +- **DO** include reference URLs in class docstrings +- **DO** always update CHANGELOG.md under the current unreleased + version +- **DO** add the model to `docs/modules/models.rst` in the + `earth2studio.models.dx` section (alphabetical order) +- **DO NOT** inherit from `PrognosticMixin` or add `create_iterator` + — diagnostic models are single-pass, not time-stepping +- **DO NOT** add `lead_time` dimension unless the model genuinely needs temporal + context (e.g., solar radiation, wind gust models that depend on forecast lead time) +- **DO NOT** make a general base class with intent to reuse the + wrapper across models +- **DO NOT** over-populate the `load_model()` API — only expose + essential parameters +- **NEVER** commit, hardcode, or include API keys, secrets, tokens, + or credentials in source code, sample scripts, commit messages, + PR descriptions, or any file tracked by git From 9e4090eb9780027b69bddd8f9149b76288aa898b Mon Sep 17 00:00:00 2001 From: Nicholas Geneva <ngeneva@nvidia.com> Date: Mon, 13 Apr 2026 22:08:10 +0000 Subject: [PATCH 2/7] feat: add validate-prognostic-wrapper skill Add validation skill for prognostic model (px) wrappers, mirroring the validate-diagnostic-wrapper pattern. Covers test coverage, single-step and multi-step reference comparison, forecast evolution plots, iterator validation, PR submission, and Greptile review. --- .../validate-prognostic-wrapper/SKILL.md | 926 ++++++++++++++++++ 1 file changed, 926 insertions(+) create mode 100644 .claude/skills/validate-prognostic-wrapper/SKILL.md diff --git a/.claude/skills/validate-prognostic-wrapper/SKILL.md b/.claude/skills/validate-prognostic-wrapper/SKILL.md new file mode 100644 index 000000000..83c3456d7 --- /dev/null +++ b/.claude/skills/validate-prognostic-wrapper/SKILL.md @@ -0,0 +1,926 @@ +--- +name: validate-prognostic-wrapper +description: Validate a newly created Earth2Studio prognostic model wrapper by running tests, performing reference comparison (single-step and multi-step), generating sanity-check plots, and opening a PR with automated code review. Use this skill after completing prognostic model implementation (create-prognostic-wrapper skill Steps 0-12). +argument-hint: Name of the prognostic model class and test file (optional — will be inferred from recent changes if not provided) +--- + +# Validate Prognostic Model Wrapper + +Validate a newly created Earth2Studio prognostic model (px) wrapper by +running tests, performing reference comparison (single-step and +multi-step), generating sanity-check plots, and opening a PR with +automated code review. This skill picks up after the +`create-prognostic-wrapper` skill completes implementation +(Steps 0-12). + +> **Python Environment:** This project uses **uv** for dependency +> management. Always use the local `.venv` virtual environment +> (`source .venv/bin/activate` or prefix with `uv run python`) for all +> Python commands — installing packages, running tests, executing +> scripts, etc. Use `uv add` / `uv pip install` / `uv lock` instead of +> `pip install`. + +Each confirmation gate marked by: + +```markdown +### **[CONFIRM — <Title>]** +``` + +requires **explicit user approval** before proceeding. + +--- + +## Step 1 — Run Tests + +### 1a. Run the new test file + +```bash +uv run python -m pytest test/models/px/test_<filename>.py -v --timeout=60 +``` + +All tests must pass. Fix failures and re-run until green. + +### 1b. Run coverage report with `--slow` tests + +Run the new test file **with coverage** and the `--slow` flag to +include integration tests. The new prognostic model file must achieve +**at least 90% line coverage**: + +```bash +uv run python -m pytest test/models/px/test_<filename>.py -v \ + --slow --timeout=300 \ + --cov=earth2studio/models/px/<filename> \ + --cov-report=term-missing \ + --cov-fail-under=90 +``` + +- `--slow` enables integration tests (marked `@pytest.mark.slow`) +- `--cov=earth2studio/models/px/<filename>` scopes coverage to the + new model module only +- `--cov-report=term-missing` shows which lines are not covered +- `--cov-fail-under=90` fails the run if coverage is below 90% + +If coverage is below 90%, add additional tests or mock tests to +cover the missing lines. Common coverage gaps for px models: + +- Error handling in `output_coords` (wrong variable names, wrong dims) +- Device management paths (CPU vs CUDA) +- `create_iterator` edge cases (initial condition yield, hook calls) +- `load_model` and `load_default_package` (needs mock or package test) +- ONNX / non-PyTorch backend `.to()` logic + +Re-run until coverage is at or above 90%. + +### 1c. Run the full model test suite (optional but recommended) + +```bash +make pytest TOX_ENV=test-models +``` + +Confirm no regressions in existing model tests. + +--- + +## Step 2 — Reference Comparison & Sanity-Check + +This step validates the prognostic model wrapper produces correct +output by comparing against the original reference implementation +for both a single time step and a multi-step forecast, then +generating visual sanity-check plots. + +### 2a. Create reference comparison script + +Create a **standalone Python script** in the repo root. This is for +validation only and should **NOT** be committed to the repo. + +The script loads the reference model and the E2S wrapper side by side, +runs both on identical input (same random seed or real data), and +compares outputs with tolerance. For prognostic models, test **both** +single-step (`__call__`) and multi-step (`create_iterator`): + +```python +"""Reference comparison for <ModelName> prognostic model. + +Compares the Earth2Studio wrapper output against the original reference +implementation to verify numerical agreement for both single-step and +multi-step forecasts. + +This script is for validation only — do NOT commit to the repo. +""" +import torch +import numpy as np + +# --- Reference model --- +# TODO: Load original model per reference repo instructions +# Uncomment and adapt the following lines: +# ref_model = ... +# ref_input = ... +# ref_output_single = ref_model(ref_input) # single step +# ref_outputs_multi = [ref_output_single] +# current = ref_output_single +# for step in range(N_STEPS): +# current = ref_model(current) +# ref_outputs_multi.append(current) +raise NotImplementedError( + "Fill in the reference model code above, then remove this line." +) + +# --- Earth2Studio wrapper --- +from earth2studio.models.px import ModelName + +model = ModelName(...) # or ModelName.load_model(package) +device = "cuda" if torch.cuda.is_available() else "cpu" +model = model.to(device) + +input_coords = model.input_coords() +# Construct input tensor matching the reference input +# Use the same random seed or identical real data for both +shape = tuple(max(len(v), 1) for v in input_coords.values()) +torch.manual_seed(42) +x = torch.randn(shape, device=device) + +# --- Single-step comparison --- +e2s_output_single, out_coords = model(x, input_coords) + +ref_output_single = ref_output_single.to(e2s_output_single.device) +max_abs_single = (ref_output_single - e2s_output_single).abs().max().item() +max_rel_single = ( + (ref_output_single - e2s_output_single).abs() + / (ref_output_single.abs() + 1e-8) +).max().item() +corr_single = torch.corrcoef( + torch.stack([ + ref_output_single.flatten(), + e2s_output_single.flatten(), + ]) +)[0, 1].item() + +print("=== Single-step comparison ===") +print(f"Max absolute difference: {max_abs_single:.2e}") +print(f"Max relative difference: {max_rel_single:.2e}") +print(f"Correlation: {corr_single:.8f}") + +assert torch.allclose( + ref_output_single, e2s_output_single, rtol=1e-4, atol=1e-5 +), f"Single-step mismatch! Max abs diff: {max_abs_single:.2e}" + +# --- Multi-step comparison --- +N_STEPS = 5 # Adapt to model time step (e.g., 5 steps of 6h = 30h) +iterator = model.create_iterator(x, input_coords) + +# Skip initial condition (step 0) +step0_x, step0_coords = next(iterator) + +print(f"\n=== Multi-step comparison ({N_STEPS} steps) ===") +for step_i in range(N_STEPS): + e2s_step, e2s_coords = next(iterator) + ref_step = ref_outputs_multi[step_i + 1].to(e2s_step.device) + + max_abs = (ref_step - e2s_step).abs().max().item() + corr = torch.corrcoef( + torch.stack([ref_step.flatten(), e2s_step.flatten()]) + )[0, 1].item() + lead = e2s_coords["lead_time"] + + print(f"Step {step_i + 1} (lead_time={lead}): " + f"max_abs={max_abs:.2e}, corr={corr:.8f}") + + assert torch.allclose(ref_step, e2s_step, rtol=1e-3, atol=1e-4), \ + f"Multi-step mismatch at step {step_i + 1}! Max abs: {max_abs:.2e}" + +print("\nPASS: Reference comparison successful (single + multi-step).") +``` + +### 2b. Summarize model capabilities to user + +Before generating sanity-check plots, **present a summary table** to +the user covering the model's capabilities: + +> **Model Summary for `<ClassName>`:** +> +> | Property | Value | +> |---|---| +> | **Input variables** | `var1`, `var2`, ... (N total) | +> | **Output variables** | `out1`, `out2`, ... (M total) | +> | **Time step** | Xh (e.g., 6h, 24h) | +> | **Spatial resolution** | X.XX deg x Y.YY deg (NxM) | +> | **History required** | None / Xh (e.g., -6h, 0h) | +> | **Checkpoint size** | XX MB | +> | **Checkpoint source** | NGC / HuggingFace / S3 | +> | **Inference time** | ~XX ms per step (on GPU/CPU) | + +This summary helps the user verify the wrapper matches their +expectations for the model. + +### 2c. Generate sanity-check plot script + +Create a **standalone Python script** in the repo root. This is for +PR reviewer reference only and should **NOT** be committed to the +repo. + +Choose the appropriate template based on the model's output type: + +#### Spatial forecast evolution (most common — gridded weather models) + +```python +"""Sanity-check plot for <ModelName> prognostic model. + +This script is for PR review only — do NOT commit to the repo. +Runs a multi-step forecast and visualizes the evolution of key +variables over lead time. +""" +import matplotlib.pyplot as plt +import numpy as np +import torch + +from earth2studio.data import Random, fetch_data +from earth2studio.models.px import ModelName + +# Load model +model = ModelName.from_pretrained() +device = "cuda" if torch.cuda.is_available() else "cpu" +model = model.to(device) + +# Prepare input +time = np.array([np.datetime64("2024-01-01T00:00")]) +input_coords = model.input_coords() +input_coords["time"] = time +ds = Random(input_coords) +x, coords = fetch_data(ds, time, input_coords["variable"], device=device) + +# Run multi-step forecast +N_STEPS = 5 +iterator = model.create_iterator(x, coords) + +# Collect outputs +steps = [] +for i, (step_x, step_coords) in enumerate(iterator): + steps.append((step_x.cpu().numpy(), dict(step_coords))) + if i >= N_STEPS: + break + +# Pick 2-3 representative variables +var_list = list(steps[0][1]["variable"]) +plot_vars = var_list[:3] # First 3 variables, or pick specific ones +n_vars = len(plot_vars) + +# Plot: rows = variables, columns = time steps (0, mid, final) +step_indices = [0, N_STEPS // 2, N_STEPS] +n_cols = len(step_indices) +fig, axes = plt.subplots(n_vars, n_cols, figsize=(5 * n_cols, 4 * n_vars)) +if n_vars == 1: + axes = axes[np.newaxis, :] +if n_cols == 1: + axes = axes[:, np.newaxis] + +for row, var in enumerate(plot_vars): + var_idx = var_list.index(var) + for col, si in enumerate(step_indices): + data, sc = steps[si] + # Shape: (batch, time, lead_time, variable, lat, lon) + data_2d = data[0, 0, 0, var_idx, :, :] + lead = sc["lead_time"] + im = axes[row, col].contourf(data_2d, cmap="turbo", levels=20) + axes[row, col].set_title(f"{var} | lead={lead}") + plt.colorbar(im, ax=axes[row, col], shrink=0.8) + +plt.suptitle(f"<ModelName> forecast evolution", y=1.02) +plt.tight_layout() +plt.savefig("sanity_check_<model_name>.png", dpi=150, bbox_inches="tight") +print("Saved: sanity_check_<model_name>.png") +``` + +#### Time-series at selected grid points + +For models where spatial patterns are less meaningful (e.g., random +input), or to supplement the spatial plot, show how variables evolve +over lead time at specific grid points: + +```python +"""Sanity-check time series for <ModelName> prognostic model. + +Shows how selected variables evolve over forecast lead time at +specific grid points. +""" +import matplotlib.pyplot as plt +import numpy as np +import torch + +from earth2studio.data import Random, fetch_data +from earth2studio.models.px import ModelName + +# Load model and run forecast +model = ModelName.from_pretrained() +device = "cuda" if torch.cuda.is_available() else "cpu" +model = model.to(device) + +time = np.array([np.datetime64("2024-01-01T00:00")]) +input_coords = model.input_coords() +input_coords["time"] = time +ds = Random(input_coords) +x, coords = fetch_data(ds, time, input_coords["variable"], device=device) + +N_STEPS = 10 +iterator = model.create_iterator(x, coords) + +# Collect time series at a central grid point +lat_idx = len(input_coords["lat"]) // 2 +lon_idx = len(input_coords["lon"]) // 2 +var_list = list(input_coords["variable"]) +plot_vars = var_list[:4] # Pick representative variables + +lead_times = [] +values = {v: [] for v in plot_vars} + +for i, (step_x, step_coords) in enumerate(iterator): + lead_times.append(step_coords["lead_time"][0]) + for var in plot_vars: + var_idx = var_list.index(var) + val = step_x[0, 0, 0, var_idx, lat_idx, lon_idx].cpu().item() + values[var].append(val) + if i >= N_STEPS: + break + +# Convert lead times to hours for plotting +lead_hours = [lt / np.timedelta64(1, "h") for lt in lead_times] + +fig, axes = plt.subplots(len(plot_vars), 1, + figsize=(10, 3 * len(plot_vars)), + sharex=True) +if len(plot_vars) == 1: + axes = [axes] + +for ax, var in zip(axes, plot_vars): + ax.plot(lead_hours, values[var], marker="o", markersize=3) + ax.set_ylabel(var) + ax.grid(True, alpha=0.3) + ax.set_title(f"{var} at lat_idx={lat_idx}, lon_idx={lon_idx}") + +axes[-1].set_xlabel("Lead time (hours)") +plt.suptitle(f"<ModelName> — time series at grid center", y=1.02) +plt.tight_layout() +plt.savefig("sanity_check_<model_name>_timeseries.png", + dpi=150, bbox_inches="tight") +print("Saved: sanity_check_<model_name>_timeseries.png") +``` + +#### Iterator behavior validation + +For verifying `create_iterator` mechanics (initial condition yield, +lead_time progression, hook application): + +```python +"""Iterator validation for <ModelName> prognostic model. + +Verifies create_iterator yields correct initial condition, increments +lead_time correctly, and that front/rear hooks are called. +""" +import numpy as np +import torch + +from earth2studio.data import Random, fetch_data +from earth2studio.models.px import ModelName + +model = ModelName.from_pretrained() +device = "cuda" if torch.cuda.is_available() else "cpu" +model = model.to(device) + +time = np.array([np.datetime64("2024-01-01T00:00")]) +input_coords = model.input_coords() +input_coords["time"] = time +ds = Random(input_coords) +x, coords = fetch_data(ds, time, input_coords["variable"], device=device) + +N_STEPS = 5 +iterator = model.create_iterator(x, coords) + +print("=== Iterator validation ===") +for i, (step_x, step_coords) in enumerate(iterator): + lead = step_coords["lead_time"] + print(f"Step {i}: shape={step_x.shape}, lead_time={lead}, " + f"device={step_x.device}") + + if i == 0: + # Step 0 must be the initial condition (unchanged input) + assert torch.allclose(step_x, x), \ + "Step 0 should yield the initial condition unchanged" + assert np.array_equal(lead, coords["lead_time"]), \ + "Step 0 lead_time should match input lead_time" + print(" -> Initial condition verified") + + if i >= N_STEPS: + break + +# Verify lead_time progression +expected_step = input_coords["lead_time"][0] +print(f"\nExpected time step: {expected_step}") +print(f"Final lead_time after {N_STEPS} steps: {step_coords['lead_time']}") +print("PASS: Iterator validation successful.") +``` + +### 2d. Run comparison and sanity-check scripts + +Execute all scripts: + +```bash +uv run python reference_comparison_<model_name>.py +uv run python sanity_check_<model_name>.py +``` + +Verify that: + +- The reference comparison passes for both single-step and multi-step +- The sanity-check script runs without errors +- Output PNGs are generated +- `create_iterator` yields correct initial condition at step 0 +- `lead_time` increments correctly at each step + +### 2e. **[CONFIRM — Sanity-Check & Comparison]** + +**You MUST ask the user to visually inspect the generated plot(s) +before proceeding.** Do not skip this step even if the scripts ran +without errors — a successful run does not guarantee the plots are +correct (e.g., empty axes, wrong colorbar range, garbled data). + +Tell the user the absolute path to the generated image file(s) and +the reference comparison metrics, then ask them to inspect: + +> The reference comparison and sanity-check scripts ran successfully. +> +> **Single-step reference comparison:** +> +> - Max absolute difference: `<value>` +> - Max relative difference: `<value>` +> - Correlation: `<value>` +> +> **Multi-step reference comparison (N steps):** +> +> - Step 1: max_abs=`<value>`, corr=`<value>` +> - Step N: max_abs=`<value>`, corr=`<value>` +> - (error may grow with lead time — this is expected) +> +> **Sanity-check plot saved to:** +> `/absolute/path/to/sanity_check_<model_name>.png` +> +> **Please open this image and confirm it looks correct.** Check: +> +> 1. Data is visible on the axes at all time steps (not blank/empty) +> 2. Values are in physically reasonable ranges +> 3. Spatial patterns evolve smoothly over lead time +> 4. No sudden jumps, NaN explosions, or frozen fields +> 5. Lead time labels increment correctly +> +> Does the plot look correct and do the reference comparison metrics +> look acceptable? + +**Do not proceed to Step 3 until the user explicitly confirms.** If +the user reports problems, debug and fix the issue, re-run the +scripts, and ask the user to inspect again. + +--- + +## Step 3 — Branch, Commit & Open PR + +### **[CONFIRM — Ready to Submit]** + +Before proceeding, confirm with the user: + +> All implementation and validation steps are complete: +> +> - Prognostic model class implemented with correct method ordering +> - Triple inheritance: `torch.nn.Module + AutoModelMixin + PrognosticMixin` +> - Coordinate system with proper `handshake_dim` indices +> - `__call__` (single step) and `create_iterator` (multi-step) implemented +> - `create_iterator` yields initial condition first, uses front/rear hooks +> - Model loading implemented (`load_default_package`, `load_model`) +> - Registered in `earth2studio/models/px/__init__.py` +> - Documentation added to `docs/modules/models.rst` +> - Reference URLs included in class docstrings +> - CHANGELOG.md updated +> - Format, lint, and license checks pass +> - Unit tests written and passing with >= 90% coverage +> - Dependencies in pyproject.toml confirmed +> - Reference comparison passes (single-step and multi-step) +> - Sanity-check plots generated and confirmed by user +> +> Ready to create a branch, commit, and prepare a PR? + +### 3a. Create branch and commit + +```bash +git checkout -b feat/prognostic-model-<name> +git add earth2studio/models/px/<filename>.py \ + earth2studio/models/px/__init__.py \ + test/models/px/test_<filename>.py \ + pyproject.toml \ + CHANGELOG.md \ + docs/modules/models.rst +git commit -m "feat: add <ClassName> prognostic model + +Add <ClassName> prognostic model for <brief description>. +Includes unit tests and documentation." +``` + +Do **NOT** add the sanity-check scripts, comparison scripts, or +their output images. + +### 3b. Identify the fork remote and push branch + +The working repository is typically a **fork** of +`NVIDIA/earth2studio`. Before pushing, confirm which git remote +points to the user's fork: + +```bash +git remote -v +``` + +Ask the user: + +> Which git remote is your fork of `NVIDIA/earth2studio`? +> (Usually `origin` — e.g., `git@github.com:<user>/earth2studio.git`) + +Then push the feature branch to the **fork** remote: + +```bash +git push -u <fork-remote> feat/prognostic-model-<name> +``` + +### 3c. Open Pull Request (fork -> NVIDIA/earth2studio) + +> **Important:** PRs must be opened **from the fork** to the +> **upstream source repository** `NVIDIA/earth2studio`. The branch +> lives on the fork; the PR targets `main` on the upstream repo. + +Use `gh pr create` with explicit `--repo` and `--head` flags: + +```bash +gh pr create \ + --repo NVIDIA/earth2studio \ + --base main \ + --head <fork-owner>:feat/prognostic-model-<name> \ + --title "feat: add <ClassName> prognostic model" \ + --body "..." +``` + +Where `<fork-owner>` is the GitHub username that owns the fork. + +The PR body should follow this prognostic-model-specific template: + +````markdown +## Description + +Add `<ClassName>` prognostic model for <brief description>. + +Closes #<issue_number> (if applicable) + +### Model details + +| Property | Value | +|---|---| +| **Architecture** | PyTorch / ONNX / JAX | +| **Time step** | Xh (e.g., 6h, 24h) | +| **Input variables** | N variables (list or link) | +| **Output variables** | M variables (list or link) | +| **Spatial resolution** | X° x Y° (NxM grid) | +| **History required** | None / Xh (e.g., [-6h, 0h]) | +| **Checkpoint source** | NGC / HuggingFace / S3 | +| **Checkpoint size** | XX MB | +| **Reference** | <paper/repo URL> | + +### Dependencies added + +| Package | Version | License | License URL | Reason | +|---|---|---|---|---| +| `<package>` | `>=X.Y` | <License> | [link](<URL>) | <reason> | + +*(or "No new dependencies needed")* + +When filling this table, look up each new dependency's license: + +1. Check the package's PyPI page or repository for the license type +2. Link directly to the license file +3. Flag any **non-permissive licenses** (GPL, AGPL, SSPL) — these + may be incompatible with the project's Apache-2.0 license + +### Reference comparison + +**Single step:** + +- Max absolute difference: <value> +- Max relative difference: <value> +- Correlation: <value> + +**Multi-step (N steps):** + +- Step 1: max_abs=<value>, corr=<value> +- Step N: max_abs=<value>, corr=<value> + +### Validation + +See sanity-check plots in PR comments below. + +## Checklist + +- [x] I am familiar with the [Contributing Guidelines][contrib]. +- [x] New or existing tests cover these changes. +- [x] The documentation is up to date with these changes. +- [x] The [CHANGELOG.md][changelog] is up to date with these changes. +- [ ] An [issue][issues] is linked to this pull request. +- [ ] Assess and address Greptile feedback (AI code review bot). + +[contrib]: https://github.com/NVIDIA/earth2studio/blob/main/CONTRIBUTING.md +[changelog]: https://github.com/NVIDIA/earth2studio/blob/main/CHANGELOG.md +[issues]: https://github.com/NVIDIA/earth2studio/issues +```` + +### 3d. Post sanity-check as PR comment + +After the PR is created, post the sanity-check visualization as a +separate **PR comment** so it is immediately visible to reviewers. + +#### Image upload limitation + +**GitHub has no CLI or REST API for uploading images to PR comments.** +The only way to embed an image is via the browser's drag-and-drop +editor or by referencing an already-hosted URL. + +**Practical workflow:** + +1. Write the comment body to a temp file (avoids shell quoting issues + with heredocs containing backticks and markdown). +2. Post the comment **without** the image — include the validation + table, reference comparison metrics, the full sanity-check script, + and a placeholder line. +3. Tell the user to drag the image into the browser editor. + +```bash +# 1. Write body to a temp file (use your editor tool, not heredoc) + +# 2. Post the comment +gh api -X POST repos/NVIDIA/earth2studio/issues/<PR_NUMBER>/comments \ + -F "body=@/tmp/pr_comment_body.md" \ + --jq '.html_url' +``` + +Do **not** waste time trying `curl` uploads, GraphQL file mutations, +or the `uploads.github.com` asset endpoint — they do not work for +issue/PR comment images. + +#### Comment content template + +```markdown +## Sanity-Check Validation + +**Model:** `<ClassName>` — <brief description> +**Architecture:** PyTorch / ONNX / JAX +**Time step:** Xh +**Test environment:** <GPU model or CPU> + +### Reference Comparison + +**Single step:** + +| Metric | Value | +|--------|-------| +| Max absolute difference | <value> | +| Max relative difference | <value> | +| Correlation | <value> | + +**Multi-step (N steps):** + +| Step | Lead time | Max abs diff | Correlation | +|------|-----------|-------------|-------------| +| 1 | Xh | <value> | <value> | +| ... | ... | ... | ... | +| N | NXh | <value> | <value> | + +### Model Summary + +| Property | Value | +|----------|-------| +| Input variables | <list or count> | +| Output variables | <list or count> | +| Spatial resolution | X° x Y° (NxM) | +| Time step | Xh | +| History required | None / [-Xh, 0h] | +| Inference time | ~XX ms per step | + +**Key findings:** + +- <bullet summarizing single-step numerical agreement> +- <bullet on multi-step error growth behavior> +- <bullet on spatial pattern quality over forecast horizon> + +> **TODO:** Attach sanity-check image by editing this comment in +> the browser. + +<details> +<summary>Sanity-check script (click to expand)</summary> + +```python +PASTE THE FULL WORKING SCRIPT HERE — not a truncated excerpt. +The script must be copy-pasteable and produce the plot end-to-end. +``` + +</details> +``` + +**Important:** Always paste the **complete, runnable** script — not +a shortened version. Reviewers should be able to reproduce the plot +by copying the script directly. + +#### Finalize + +After posting, inform the user of: + +1. The comment URL +2. The local path to the image file for manual attachment +3. Instructions: *"Edit the comment in your browser and drag the + image file into the editor to embed it."* + +> **Note:** The sanity-check image and script are for PR review +> purposes only — they must NOT be committed to the repository. + +--- + +## Step 4 — Automated Code Review (Greptile) + +After the PR is created and pushed, an automated code review from +**greptile-apps** (Greptile) will be posted as PR review comments. +Wait for this review, then process the feedback. + +### 4a. Wait for Greptile review + +Poll for review comments from `greptile-apps[bot]` every 30 seconds +for up to **5 minutes**. Time out gracefully if no review arrives: + +```bash +# Poll loop — check every 30s, timeout after 5 minutes (10 attempts) +for i in $(seq 1 10); do + REVIEW_ID=$(gh api repos/NVIDIA/earth2studio/pulls/<PR_NUMBER>/reviews \ + --jq '.[] | select(.user.login == "greptile-apps[bot]") | .id' 2>/dev/null) + if [ -n "$REVIEW_ID" ]; then + echo "Greptile review found: $REVIEW_ID" + break + fi + echo "Attempt $i/10 — no review yet, waiting 30s..." + sleep 30 +done +``` + +If no review after 5 minutes, inform the user: + +> Greptile hasn't posted a review after 5 minutes. This can happen if +> the review bot is busy or the PR hasn't triggered it. You can: +> +> 1. Ask me to check again later +> 2. Skip this step and proceed without automated review +> 3. Manually request a review from Greptile on the PR page + +### 4b. Pull and parse review comments + +Once the review is posted, fetch all comments: + +```bash +# Get all review comments on the PR +gh api repos/NVIDIA/earth2studio/pulls/<PR_NUMBER>/comments \ + --jq '.[] | select(.user.login == "greptile-apps[bot]") | + {path: .path, line: .diff_hunk, body: .body}' +``` + +Also fetch the top-level review body: + +```bash +gh api repos/NVIDIA/earth2studio/pulls/<PR_NUMBER>/reviews \ + --jq '.[] | select(.user.login == "greptile-apps[bot]") | .body' +``` + +### 4c. Categorize and present to user + +Parse each comment and categorize it: + +| Category | Description | Default action | +|---|---|---| +| **Bug / correctness** | Logic errors, wrong behavior | Fix | +| **Style / convention** | Naming, formatting, patterns | Fix if valid | +| **Performance** | Inefficiency, resource waste | Evaluate | +| **Documentation** | Missing/wrong docs, docstrings | Fix | +| **Suggestion** | Alternative approach, nice-to-have | User decides | +| **False positive** | Incorrect or irrelevant feedback | Dismiss | + +### **[CONFIRM — Review Triage]** + +Present each comment to the user in a summary table: + +```markdown +| # | File | Line | Category | Summary | Proposed Action | +|---|------|------|----------|---------|-----------------| +| 1 | <model>.py | 142 | Bug | Missing null check | Fix: add guard | +| 2 | <model>.py | 305 | Style | Use f-string | Fix: convert | +| 3 | <model>.py | 45 | Suggestion | Add type alias | Skip: not needed | +| ... | ... | ... | ... | ... | ... | +``` + +For each comment, briefly explain: + +- What Greptile flagged +- Whether you agree or disagree (with reasoning) +- Your proposed fix (or why to skip) + +Ask the user to confirm which comments to address. The user may: + +- Accept all proposed fixes +- Select specific fixes +- Override your recommendation on any comment +- Add their own fixes + +### 4d. Implement fixes + +For each accepted fix: + +1. Make the code change +2. Run `make format && make lint` after all fixes +3. Run the relevant tests: + + ```bash + uv run python -m pytest test/models/px/test_<filename>.py -v --timeout=60 + ``` + +4. Commit with a message like: + + ```bash + git commit -m "fix: address code review feedback (Greptile)" + ``` + +### 4e. Respond to review comments + +For each Greptile comment, post a reply on the PR: + +**For fixed comments:** + +```bash +gh api repos/NVIDIA/earth2studio/pulls/<PR_NUMBER>/comments/<COMMENT_ID>/replies \ + -f body="Fixed in <commit_sha>. <brief description of fix>" +``` + +**For dismissed comments (false positives / won't fix):** + +```bash +gh api repos/NVIDIA/earth2studio/pulls/<PR_NUMBER>/comments/<COMMENT_ID>/replies \ + -f body="Won't fix — <brief justification>" +``` + +### 4f. Push and resolve + +```bash +git push <fork-remote> feat/prognostic-model-<name> +``` + +After pushing, resolve all addressed review threads if possible. + +Inform the user of the final state: + +- How many comments were fixed +- How many were dismissed (with reasons) +- Any remaining open threads + +--- + +## Reminders + +- **DO** use the repo's local `uv` `.venv` to run Python with + `uv run python` +- **DO NOT** commit sanity-check/comparison scripts or images to + the repo +- **DO** use `loguru.logger` for logging, never `print()`, inside + `earth2studio/` +- **DO** ensure all public functions have full type hints (mypy-clean) +- **DO** maintain alphabetical order in `__init__.py` exports, + RST file entries, and CHANGELOG entries +- **DO** follow the canonical PX method ordering: + `__init__`, `input_coords`, `output_coords`, + `load_default_package`, `load_model`, `to`, private methods, + `__call__`, `_default_generator`, `create_iterator` +- **DO** use `handshake_dim` indices matching each dimension's + position in the `CoordSystem` OrderedDict — check existing px + models for the predominant convention +- **DO** include reference URLs in class docstrings +- **DO** always update CHANGELOG.md under the current unreleased + version +- **DO** add the model to `docs/modules/models.rst` in the + `earth2studio.models.px` section (alphabetical order) +- **DO** inherit from `torch.nn.Module + AutoModelMixin + PrognosticMixin` + (triple inheritance) +- **DO** yield initial condition first in `create_iterator` + (step 0 = unchanged input) +- **DO** use `self.front_hook()` and `self.rear_hook()` in + `create_iterator` for perturbation injection / post-processing +- **DO** include `lead_time` in `input_coords` starting at + `np.timedelta64(0, "h")` and increment it in `output_coords` +- **DO NOT** make a general base class with intent to reuse the + wrapper across models +- **DO NOT** over-populate the `load_model()` API — only expose + essential parameters +- **NEVER** commit, hardcode, or include API keys, secrets, tokens, + or credentials in source code, sample scripts, commit messages, + PR descriptions, or any file tracked by git From ba274abc87df258fa95b5bf41e854740ab536b38 Mon Sep 17 00:00:00 2001 From: Nicholas Geneva <ngeneva@nvidia.com> Date: Mon, 13 Apr 2026 22:53:11 +0000 Subject: [PATCH 3/7] feat: add create and validate assimilation model wrapper skills Add two new Claude skills for data assimilation (DA) models: - create-assimilation-wrapper: Steps 0-8 for wrapping DA models with FrameSchema/CoordSystem, send-protocol generators, cupy/cudf support - validate-assimilation-wrapper: Steps 1-5 for tests, reference comparison, sanity-check plots, PR submission, Greptile review Includes spec and implementation plan documents. --- .../create-assimilation-wrapper/SKILL.md | 1187 +++++++++++++ .../validate-assimilation-wrapper/SKILL.md | 1553 +++++++++++++++++ .../2026-04-13-assimilation-model-skills.md | 582 ++++++ ...-04-13-assimilation-model-skills-design.md | 311 ++++ 4 files changed, 3633 insertions(+) create mode 100644 .claude/skills/create-assimilation-wrapper/SKILL.md create mode 100644 .claude/skills/validate-assimilation-wrapper/SKILL.md create mode 100644 docs/superpowers/plans/2026-04-13-assimilation-model-skills.md create mode 100644 docs/superpowers/specs/2026-04-13-assimilation-model-skills-design.md diff --git a/.claude/skills/create-assimilation-wrapper/SKILL.md b/.claude/skills/create-assimilation-wrapper/SKILL.md new file mode 100644 index 000000000..fd2b7a30e --- /dev/null +++ b/.claude/skills/create-assimilation-wrapper/SKILL.md @@ -0,0 +1,1187 @@ +--- +name: create-assimilation-wrapper +description: >- + Create a new Earth2Studio data assimilation model (da) wrapper from a + reference inference script or repository. DA models ingest sparse + observations (DataFrames) and/or gridded state arrays (DataArrays) and + produce analysis output — they do NOT use @batch_func, @batch_coords, + or PrognosticMixin. +argument-hint: URL or local path to reference inference script/repo (optional — will be asked if not provided) +--- + +# Create Assimilation Model Wrapper + +Create a new Earth2Studio data assimilation model wrapper by following every step below in order. +Each confirmation gate marked by starting with: + +```markdown +### **[CONFIRM — <Title>]** +``` + +requires explicit user approval before proceeding. + +> **Environment note**: Use `uv run python` for all Python +> commands. The project uses a `uv`-managed virtual environment +> — never install packages globally or use bare `python`. + +--- + +## Critical Differences from px/dx Models + +DA models are **not** tensor-in / tensor-out like px and dx models. The entire I/O +contract is different. Review this table **before** writing any code: + +| Aspect | px / dx | da | +| ------------------- | ------------------------------ | ---------------------------------------- | +| **Primary input** | `Tensor` + `CoordSystem` | `pd.DataFrame` or `xr.DataArray` | +| **Primary output** | `Tensor` + `CoordSystem` | `xr.DataArray` or `pd.DataFrame` | +| **Coord schemas** | `CoordSystem` only | `FrameSchema` + `CoordSystem` | +| **Coord returns** | Single dict | **Tuple** (even for single) | +| **Batch handling** | `@batch_func` / `@batch_coords` | **Neither** — N/A for DataFrame I/O | +| **PrognosticMixin** | Yes (px) / No (dx) | **No** | +| **Time integration** | `create_iterator` yields | `create_generator` with **send** | +| **Generator prime** | Yields initial condition | Yields `None` / state, `.send()` | +| **Generator cleanup** | N/A | Must handle `GeneratorExit` | +| **Init data** | No | Optional via `init_coords()` | +| **Time metadata** | Coordinate dimension | `obs.attrs["request_time"]` | +| **GPU data** | `Tensor` on device | **cupy** arrays, **cudf** DFs | +| **Input validation** | `handshake_dim`/`handshake_coords` | `validate_observation_fields()` | +| **Time filtering** | N/A | `filter_time_range()` | +| **Tensor conversion** | Already tensors | `dfseries_to_torch()` | +| **Device tracking** | `device_buffer` / `parameters()` | `device_buffer` + `@property` | + +--- + +## Step 0 — Obtain Reference & Analyze Model + +### 0a. Obtain reference script / repository + +If `$ARGUMENTS` is provided, use it as the reference inference script or repository. + +- If it is a URL, use WebFetch to retrieve the content. +- If it is a local file path, read it directly. + +If `$ARGUMENTS` is empty or not provided, ask the user: + +> Please provide a reference inference script or +> repository URL/path that demonstrates how this model +> runs inference. This will be used to understand the +> model architecture, dependencies, input/output shapes, +> observation schema, and variable mapping. + +Store the reference code content for use in subsequent steps. + +### 0b. Analyze reference model + +After obtaining the reference, analyze it for: + +- **Input types**: Does the model ingest `pd.DataFrame` + (sparse observations), `xr.DataArray` (gridded fields), + or both? +- **Output types**: Does it produce `xr.DataArray` + (gridded analysis), `pd.DataFrame` (corrected + observations), or both? +- **Stateful vs stateless**: Does the model maintain + internal state across time steps (e.g., background field), + or is each call independent? Stateless models return + `None` from `init_coords()`. +- **`@torch.inference_mode()` safety**: Does the forward + pass require gradients (e.g., DPS guidance through a + denoiser)? If so, `@torch.inference_mode()` must be + omitted and the reason documented. +- **Dependencies**: External packages required (physicsnemo, scipy, healpy, cudf, cupy, etc.) + +Present the analysis to the user: + +> **Model Analysis Summary** +> +> - **Input type(s):** [DataFrame / DataArray / both] +> - **Output type(s):** [DataArray / DataFrame / both] +> - **Stateful/Stateless:** [stateful — needs init data / stateless — no init data] +> - **Inference mode safe:** [yes / no — reason] +> - **Key dependencies:** [list] +> +> Evidence: +> +> - [list key indicators from reference code] + +### **[CONFIRM — Model Analysis]** + +Ask the user to confirm the analysis before proceeding. + +--- + +## Step 1 — Examine Reference & Propose Dependencies + +### 1a. Analyze the reference code + +Examine the reference inference script/repo to identify: + +- **Python packages** required (e.g., `physicsnemo`, `scipy`, `healpy`, `cudf`, `cupy`, custom packages) +- **Model architecture** (PyTorch module, ONNX, etc.) +- **Observation schema** (DataFrame columns, variable names, coordinate dimensions) +- **Output grid specification** (lat/lon resolution, projection, etc.) +- **Checkpoint format** (`.pt`, `.onnx`, `.safetensors`, `.mdlus`, etc.) + +### 1b. Propose pyproject.toml dependency group + +Propose a new optional dependency group for `pyproject.toml`. +The group name must follow the pattern `da-<model-name>`: + +```toml +# In [project.optional-dependencies] section of pyproject.toml +da-model-name = [ + "package1>=version", + "package2", +] +``` + +Look at the existing groups in `pyproject.toml` for reference on naming and version pinning conventions. + +Also propose adding the new group to the `all` aggregate in the appropriate line (da models). + +Highlight `cudf` and `cupy` as optional GPU acceleration +packages — these are not required but improve performance. + +### **[CONFIRM — Dependencies]** + +Present to the user: + +1. The proposed dependency group name (`da-<model-name>`) +2. The list of packages with versions +3. Ask if the packages and group name look correct + +--- + +## Step 2 — Add Dependencies to pyproject.toml + +After confirmation, edit `pyproject.toml`: + +1. Add the new optional dependency group in alphabetical order among the per-model extras +2. Add the group to the `all` aggregate (in the da models line) + +--- + +## Step 3 — Create Skeleton Class File + +### 3a. Determine class name and file name + +Based on the model name from the reference, propose: + +- **Class name**: PascalCase (e.g., `StormCastSDA`, `InterpEquirectangular`, `HealDA`) +- **File name**: lowercase with underscores (e.g., `sda_stormcast.py`, `interp.py`, `healda.py`) +- **File path**: `earth2studio/models/da/<filename>.py` + +### 3b. Write skeleton with pseudocode + +Create the file with the full structure but pseudocode implementations. +Every `.py` file in `earth2studio/` **must** start with this license +header: + +```python +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +``` + +Dual inheritance: `torch.nn.Module + AutoModelMixin` (NO `PrognosticMixin`). +Class-level `@check_optional_dependencies()` decorator. + +### Canonical method ordering for DA models + +Methods in the class **must** appear in this order (11 method slots): + +1. `__init__` — register `device_buffer`, store model params, normalize time tolerance +2. `device` property — `return self.device_buffer.device` +3. `init_coords` — `None` for stateless, tuple of `CoordSystem`/`FrameSchema` for stateful +4. `input_coords` — tuple of `FrameSchema` (for DataFrame) or `CoordSystem` (for DataArray) +5. `output_coords` — accept `input_coords` tuple + `request_time` kwarg, return tuple +6. `load_default_package` — classmethod returning default `Package` +7. `load_model` — classmethod with `@check_optional_dependencies()` +8. `to` — device management, return `AssimilationModel` +9. Private/support methods (e.g., `_interpolate`, `_forward`, spatial lookups) +10. `__call__` — stateless forward, accept `*args: pd.DataFrame | xr.DataArray | None` +11. `create_generator` — bidirectional generator with send protocol + +### Complete skeleton template + +```python +from __future__ import annotations + +from collections import OrderedDict +from collections.abc import Generator +from typing import Any + +import numpy as np +import pandas as pd +import torch +import xarray as xr +from loguru import logger + +from earth2studio.models.auto import AutoModelMixin, Package +from earth2studio.models.da.base import AssimilationModel +from earth2studio.models.da.utils import ( + dfseries_to_torch, + filter_time_range, + validate_observation_fields, +) +from earth2studio.utils.imports import ( + OptionalDependencyFailure, + check_optional_dependencies, +) +from earth2studio.utils.time import normalize_time_tolerance +from earth2studio.utils.type import CoordSystem, FrameSchema, TimeTolerance + +try: + import cupy as cp +except ImportError: + cp = None # type: ignore[assignment] + +try: + import cudf +except ImportError: + cudf = None # type: ignore[assignment, misc] + +try: + from some_package import CoreModel +except ImportError: + OptionalDependencyFailure("da-mymodel") + CoreModel = None + + +@check_optional_dependencies() +class MyDAModel(torch.nn.Module, AutoModelMixin): + """One-line description of the DA model. + + Extended description of the model, its source, observation types it + handles, and the analysis output it produces. + + Parameters + ---------- + model : torch.nn.Module + Core neural network or inference model + time_tolerance : TimeTolerance, optional + Observation time tolerance window, by default np.timedelta64(3, "h") + + Note + ---- + For more information see: <link to paper/repo> + + Badges + ------ + region:global class:da product:atmos product:insitu + """ + + OUTPUT_VARIABLES = ["u10m", "v10m", "t2m"] + + # 1. Constructor + def __init__( + self, + model: torch.nn.Module, + time_tolerance: TimeTolerance = np.timedelta64(3, "h"), + ) -> None: + super().__init__() + self._model = model + self._tolerance = normalize_time_tolerance(time_tolerance) + self.register_buffer("device_buffer", torch.empty(0)) + # TODO: Register normalization buffers, static fields, etc. + + # 2. Device property + @property + def device(self) -> torch.device: + """Model device.""" + return self.device_buffer.device + + # 3. Init coords + def init_coords(self) -> None: + """Initialization coordinate system. + + Returns None for stateless models. Override to return a tuple of + CoordSystem / FrameSchema for stateful models that require initial + state data. + + Returns + ------- + None + No initialization data required (stateless model) + """ + return None + + # For stateful models, return a tuple instead: + # return ( + # CoordSystem(OrderedDict({ + # "time": np.empty(0), + # "lead_time": np.array([np.timedelta64(0, "h")]), + # "variable": np.array(self.OUTPUT_VARIABLES), + # "lat": self._lat, + # "lon": self._lon, + # })), + # ) + + # 4. Input coords + def input_coords(self) -> tuple[FrameSchema]: + """Input coordinate system specifying required DataFrame fields. + + Returns + ------- + tuple[FrameSchema] + Tuple containing FrameSchema with field names as keys. + Use np.empty(0, dtype=...) for dynamic dimensions and + np.array([...]) for enumerated allowed values. + """ + return ( + FrameSchema({ + "time": np.empty(0, dtype="datetime64[ns]"), + "lat": np.empty(0, dtype=np.float32), + "lon": np.empty(0, dtype=np.float32), + "observation": np.empty(0, dtype=np.float32), + "variable": np.array(self.OUTPUT_VARIABLES, dtype=str), + }), + ) + + # 5. Output coords + def output_coords( + self, + input_coords: tuple[FrameSchema | CoordSystem, ...], + request_time: np.ndarray | None = None, + **kwargs: Any, + ) -> tuple[CoordSystem]: + """Output coordinate system. + + Parameters + ---------- + input_coords : tuple[FrameSchema | CoordSystem, ...] + Input coordinate system(s) + request_time : np.ndarray | None, optional + Analysis valid time(s) + + Returns + ------- + tuple[CoordSystem] + Output coordinate system(s) + """ + if request_time is None: + request_time = np.array([np.datetime64("NaT")], dtype="datetime64[ns]") + + # Extract variables from first input coord system + if len(input_coords) > 0 and "variable" in input_coords[0]: + variables = input_coords[0]["variable"] + else: + variables = np.array(self.OUTPUT_VARIABLES, dtype=str) + + return ( + CoordSystem(OrderedDict({ + "time": request_time, + "variable": variables, + "lat": np.linspace(90, -90, 181), + "lon": np.linspace(0, 360, 360, endpoint=False), + })), + ) + + # 6. Default package + @classmethod + def load_default_package(cls) -> Package: + """Default pre-trained model package. + + Returns + ------- + Package + Model package with default checkpoint location + """ + return Package( + "hf://nvidia/my-da-model@<commit-sha>", + cache_options={"same_names": True}, + ) + + # 7. Load model + @classmethod + @check_optional_dependencies() + def load_model( + cls, + package: Package, + time_tolerance: TimeTolerance = np.timedelta64(3, "h"), + ) -> AssimilationModel: + """Load assimilation model from package. + + Parameters + ---------- + package : Package + Package containing model checkpoint and statistics + time_tolerance : TimeTolerance, optional + Observation time tolerance window + + Returns + ------- + AssimilationModel + Loaded assimilation model + """ + # TODO: Load model from package + model = CoreModel.from_checkpoint(package.resolve("model.mdlus")) + model.eval() + + # Load normalization stats, static fields, etc. + stats = np.load(package.resolve("stats.npy")) + + return cls(model=model, time_tolerance=time_tolerance) + + # 8. Device management + def to(self, device: torch.device | str) -> AssimilationModel: + """Move model to device. + + Parameters + ---------- + device : torch.device | str + Target device + + Returns + ------- + AssimilationModel + Model on target device + """ + super().to(device) + return self + + # 9. Private/support methods + def _forward(self, inputs: torch.Tensor) -> torch.Tensor: + """Internal forward pass. + + Parameters + ---------- + inputs : torch.Tensor + Preprocessed input tensor on device + + Returns + ------- + torch.Tensor + Raw model output tensor + """ + # TODO: normalize -> model -> denormalize + return self._model(inputs) + + # 10. Stateless forward pass + @torch.inference_mode() + def __call__(self, obs: pd.DataFrame | None = None) -> xr.DataArray: + """Run single-step assimilation. + + Parameters + ---------- + obs : pd.DataFrame | None, optional + Observation DataFrame with required columns and + 'request_time' in attrs + + Returns + ------- + xr.DataArray + Analysis output on the same device as the model + """ + if obs is None: + raise ValueError("obs must be provided") + + # Validate required columns + input_coords = self.input_coords() + validate_observation_fields(obs, required_fields=list(input_coords[0].keys())) + + # Extract request_time from DataFrame attrs + request_time = obs.attrs.get("request_time") + if request_time is None: + raise ValueError( + "Observation DataFrame must have 'request_time' in attrs. " + "This is typically set by earth2studio data sources." + ) + + # Get output coordinate system + (output_coords,) = self.output_coords( + input_coords, request_time=request_time + ) + + # Filter observations within time tolerance window + filtered_obs = filter_time_range( + obs, + request_time=request_time, + tolerance=self._tolerance, + time_column="time", + ) + + # Convert DataFrame columns to torch tensors (zero-copy for cudf) + obs_values = dfseries_to_torch( + filtered_obs["observation"], + dtype=torch.float32, + device=self.device, + ) + obs_lat = dfseries_to_torch( + filtered_obs["lat"], + dtype=torch.float32, + device=self.device, + ) + obs_lon = dfseries_to_torch( + filtered_obs["lon"], + dtype=torch.float32, + device=self.device, + ) + + # TODO: Run model forward pass + prediction = self._forward(obs_values) + + # Build output xr.DataArray with cupy (GPU) or numpy (CPU) + if self.device.type == "cuda" and cp is not None: + data = cp.asarray(prediction) + else: + data = prediction.cpu().numpy() + + return xr.DataArray( + data=data, + dims=list(output_coords.keys()), + coords=output_coords, + ) + + # 11. Stateful generator + def create_generator( + self, + ) -> Generator[xr.DataArray, pd.DataFrame | None, None]: + """Creates a generator for sequential data assimilation. + + Yields the current analysis state and receives the next observations + via generator.send(). Prime the generator with next(gen) before + sending observations. + + Yields + ------ + xr.DataArray + Analysis at each step + + Receives + -------- + pd.DataFrame | None + Observations for the next step. Pass None for steps with no data. + + Example + ------- + >>> gen = model.create_generator() + >>> next(gen) # prime generator (yields None) + >>> result = gen.send(obs_df) # step 1 with observations + >>> result = gen.send(None) # step 2 without observations + >>> gen.close() # clean up + """ + # Prime generator — yield None for stateless models + observations = yield None # type: ignore[misc] + try: + while True: + result = self.__call__(observations) + observations = yield result + except GeneratorExit: + logger.debug("MyDAModel generator clean up complete.") +``` + +For **stateful models** that maintain internal state, the `create_generator` pattern is: + +```python +def create_generator( + self, + x: xr.DataArray, # initial state +) -> Generator[xr.DataArray, pd.DataFrame | None, None]: + """Creates a stateful generator for sequential data assimilation. + + Parameters + ---------- + x : xr.DataArray + Initial state DataArray + + Yields + ------ + xr.DataArray + Analysis at each step + + Receives + -------- + pd.DataFrame | None + Observations for the next step + + Example + ------- + >>> gen = model.create_generator(x0) + >>> state = next(gen) # yields initial state + >>> state = gen.send(obs_df) # step 1 with observations + >>> state = gen.send(None) # step 2 without observations + >>> gen.close() # clean up + """ + # Yield initial state to prime the generator + obs = yield x + + try: + while True: + result = self.__call__(x, obs) + x = result # update internal state + obs = yield result + except GeneratorExit: + logger.debug("MyDAModel generator clean up complete.") +``` + +### **[CONFIRM — Skeleton]** + +Present to the user: + +1. The proposed class name +2. The proposed file name and path +3. The canonical method ordering (11 methods) +4. Whether stateful or stateless generator pattern applies +5. Ask if these are acceptable + +--- + +## Step 4 — Implement Coordinate System + +### 4a. Map variables to E2STUDIO_VOCAB + +Read `earth2studio/lexicon/base.py` and verify every variable the +model uses exists in `E2STUDIO_VOCAB`. The vocab contains 282 +entries including: + +| Category | Examples | +| --------------- | ------------------------------------------- | +| Surface wind | `u10m`, `v10m`, `ws10m`, `u100m`, `v100m` | +| Surface temp | `t2m`, `d2m`, `sst`, `skt` | +| Humidity | `r2m`, `q2m`, `tcwv` | +| Pressure | `sp`, `msl` | +| Precipitation | `tp`, `lsp`, `cp`, `tp06` | +| Pressure-level | `u50`-`u1000`, `v50`-`v1000`, `z50`-`z1000` | +| Cloud/radiation | `tcc`, `rlut`, `rsut` | + +If a variable in the reference model does NOT exist in +`E2STUDIO_VOCAB`, flag it to the user and discuss whether to +map it to an existing vocab entry or propose adding a new one. + +### 4b. Implement init_coords + +**Stateless models** — return `None`: + +```python +def init_coords(self) -> None: + """No initialization data required.""" + return None +``` + +**Stateful models** — return a tuple of `CoordSystem` and/or `FrameSchema`: + +```python +def init_coords(self) -> tuple[CoordSystem]: + """Initialization coordinate system for the background state. + + Returns + ------- + tuple[CoordSystem] + Tuple containing one CoordSystem for the initial gridded state + """ + return ( + CoordSystem(OrderedDict({ + "time": np.empty(0), + "lead_time": np.array([np.timedelta64(0, "h")]), + "variable": np.array(self.OUTPUT_VARIABLES), + "lat": self._lat, + "lon": self._lon, + })), + ) +``` + +### 4c. Implement input_coords + +Return a **tuple** of `FrameSchema | CoordSystem` — one entry per +positional argument that `__call__` accepts. + +For DataFrame observations, use `FrameSchema`: + +```python +def input_coords(self) -> tuple[FrameSchema]: + """Input coordinate system specifying required DataFrame fields. + + Returns + ------- + tuple[FrameSchema] + Tuple containing FrameSchema for observation DataFrame + """ + return ( + FrameSchema({ + "time": np.empty(0, dtype="datetime64[ns]"), + "lat": np.empty(0, dtype=np.float32), + "lon": np.empty(0, dtype=np.float32), + "observation": np.empty(0, dtype=np.float32), + "variable": np.array(self.OUTPUT_VARIABLES, dtype=str), + }), + ) +``` + +For models accepting multiple inputs (e.g., conventional + satellite observations): + +```python +def input_coords(self) -> tuple[FrameSchema, FrameSchema]: + conv_schema = FrameSchema({...}) + sat_schema = FrameSchema({...}) + return (conv_schema, sat_schema) +``` + +**Rules:** + +- Use `np.empty(0, dtype=...)` for unbounded/dynamic dimensions (individual observations) +- Use `np.array([...])` for enumerated allowed values (variable names) +- Always return a **tuple**, even for a single input + +### 4d. Implement output_coords + +Accept the input coordinate tuple plus `request_time` kwargs. +Return a tuple of output `CoordSystem | FrameSchema`: + +```python +def output_coords( + self, + input_coords: tuple[FrameSchema | CoordSystem, ...], + request_time: np.ndarray | None = None, + **kwargs: Any, +) -> tuple[CoordSystem]: + """Output coordinate system. + + Parameters + ---------- + input_coords : tuple[FrameSchema | CoordSystem, ...] + Input coordinate system(s) + request_time : np.ndarray | None, optional + Analysis valid time(s) + + Returns + ------- + tuple[CoordSystem] + Output coordinate system(s) + """ + if request_time is None: + request_time = np.array([np.datetime64("NaT")], dtype="datetime64[ns]") + + # Extract variables from first input coord system + if len(input_coords) > 0 and "variable" in input_coords[0]: + variables = input_coords[0]["variable"] + else: + variables = np.array(self.OUTPUT_VARIABLES, dtype=str) + + return ( + CoordSystem(OrderedDict({ + "time": request_time, + "variable": variables, + "lat": self._lat, + "lon": self._lon, + })), + ) +``` + +**Key points:** + +- Validate `CoordSystem` inputs using `handshake_dim` / `handshake_size` / `handshake_coords` +- Validate `FrameSchema` inputs using `validate_observation_fields()` +- Always return a **tuple**, even for a single output + +### **[CONFIRM — Coordinates]** + +Present to the user: + +1. The input and output variable lists and any mapping issues with `E2STUDIO_VOCAB` +2. Whether `init_coords` returns `None` (stateless) or a tuple (stateful) +3. The `FrameSchema` column names and types for observation inputs +4. The output grid specification (lat/lon dimensions) + +--- + +## Step 5 — Implement Forward Pass + +### 5a. Implement `__call__` + +The stateless forward pass accepts typed arguments matching +`input_coords`, runs inference, and returns a tuple of +`xr.DataArray | pd.DataFrame`. + +```python +@torch.inference_mode() +def __call__(self, obs: pd.DataFrame | None = None) -> xr.DataArray: + """Run single-step assimilation. + + Parameters + ---------- + obs : pd.DataFrame | None, optional + Observation DataFrame + + Returns + ------- + xr.DataArray + Analysis output on the same device as the model + """ + if obs is None: + raise ValueError("obs must be provided") + + # 1. Validate required columns + input_coords = self.input_coords() + validate_observation_fields(obs, required_fields=list(input_coords[0].keys())) + + # 2. Extract request_time from DataFrame attrs + request_time = obs.attrs.get("request_time") + if request_time is None: + raise ValueError( + "Observation DataFrame must have 'request_time' in attrs. " + "This is typically set by earth2studio data sources." + ) + + # 3. Get output coordinate system + (output_coords,) = self.output_coords(input_coords, request_time=request_time) + + # 4. Filter observations within time tolerance window + filtered_obs = filter_time_range( + obs, + request_time=request_time, + tolerance=self._tolerance, + time_column="time", + ) + + # 5. Convert DataFrame columns to torch tensors (zero-copy for cudf) + obs_values = dfseries_to_torch( + filtered_obs["observation"], dtype=torch.float32, device=self.device, + ) + obs_lat = dfseries_to_torch( + filtered_obs["lat"], dtype=torch.float32, device=self.device, + ) + obs_lon = dfseries_to_torch( + filtered_obs["lon"], dtype=torch.float32, device=self.device, + ) + + # 6. Run model forward pass + prediction = self._forward(obs_values) + + # 7. Build output xr.DataArray with cupy (GPU) or numpy (CPU) + if self.device.type == "cuda" and cp is not None: + data = cp.asarray(prediction) + else: + data = prediction.cpu().numpy() + + return xr.DataArray( + data=data, + dims=list(output_coords.keys()), + coords=output_coords, + ) +``` + +**Key points:** + +- Use `@torch.inference_mode()` on `__call__` or on the internal `_forward` method — place it + on `_forward` if preprocessing outside `_forward` needs gradients. Omit entirely if the + forward pass requires gradients (e.g., DPS guidance); document the reason if omitted +- Expect `request_time` in `obs.attrs` — validate early +- Use `validate_observation_fields()` to check required DataFrame columns +- Use `filter_time_range()` for time-window filtering +- Use `dfseries_to_torch()` for zero-copy cudf→torch conversion +- Output data must live on the same device as the model (cupy for GPU, numpy for CPU) +- Do **NOT** use `@batch_func()` — this is a px/dx convention only + +### 5b. Implement `create_generator` + +**Stateless pattern** (no initial state, delegates to `__call__`): + +```python +def create_generator( + self, +) -> Generator[xr.DataArray, pd.DataFrame | None, None]: + """Creates a generator for sequential data assimilation. + + Yields + ------ + xr.DataArray + Analysis at each step + + Receives + -------- + pd.DataFrame | None + Observations for the next step + """ + # Prime generator — yield None for stateless models + observations = yield None # type: ignore[misc] + try: + while True: + result = self.__call__(observations) + observations = yield result + except GeneratorExit: + logger.debug("MyDAModel generator clean up complete.") +``` + +**Stateful pattern** (maintains background state across steps): + +```python +def create_generator( + self, + x: xr.DataArray, # initial state +) -> Generator[xr.DataArray, pd.DataFrame | None, None]: + """Creates a stateful generator for sequential data assimilation. + + Parameters + ---------- + x : xr.DataArray + Initial state DataArray + + Yields + ------ + xr.DataArray + Analysis at each step + + Receives + -------- + pd.DataFrame | None + Observations for the next step + """ + # Yield initial state to prime the generator + obs = yield x + + try: + while True: + result = self.__call__(x, obs) + x = result # update internal state + obs = yield result + except GeneratorExit: + logger.debug("MyDAModel generator clean up complete.") +``` + +**Key points:** + +- Always yield first to prime the generator (`yield None` for stateless, `yield x` for stateful) +- Receive observations via `.send()`: `observations = yield result` +- **Always** handle `GeneratorExit` for clean-up logic (e.g., releasing GPU resources) +- Do **NOT** use `create_iterator` — that is a px convention only + +### **[CONFIRM — Forward Pass]** + +Show the user the implementation for `__call__` and `create_generator`. Ask: + +1. Does the computation logic look correct? +2. Is `@torch.inference_mode()` safe, or does the model need gradient flow? +3. Are there any special considerations (multiple observation types, custom preprocessing)? + +--- + +## Step 6 — Implement Model Loading + +### 6a. Implement load_default_package + +```python +@classmethod +def load_default_package(cls) -> Package: + """Default pre-trained model package. + + Returns + ------- + Package + Model package with default checkpoint location + """ + return Package( + "hf://nvidia/my-da-model@<commit-sha>", + cache_options={"same_names": True}, + ) +``` + +### 6b. Implement load_model + +```python +@classmethod +@check_optional_dependencies() +def load_model( + cls, + package: Package, + time_tolerance: TimeTolerance = np.timedelta64(3, "h"), +) -> AssimilationModel: + """Load assimilation model from package. + + Parameters + ---------- + package : Package + Package containing model checkpoint and statistics + time_tolerance : TimeTolerance, optional + Observation time tolerance window + + Returns + ------- + AssimilationModel + Loaded assimilation model + """ + model = CoreModel.from_checkpoint(package.resolve("model.mdlus")) + model.eval() + + # Load normalization stats, static fields, etc. + stats = np.load(package.resolve("stats.npy")) + + return cls(model=model, time_tolerance=time_tolerance) +``` + +**Key patterns:** + +- Decorate `load_model` with `@check_optional_dependencies()` +- Use `package.resolve("filename")` to get cached file paths +- Call `.eval()` on loaded neural network modules +- Only expose essential parameters — do **not** over-populate the API + +### 6c. Implement .to() + +> **Note:** When the wrapper inherits from `torch.nn.Module`, +> `super().to(device)` already handles moving all registered +> parameters, buffers, and sub-modules. A custom `to()` +> override is only needed when there is non-PyTorch state to +> manage (e.g., ONNX Runtime sessions, JAX device placement). + +```python +def to(self, device: torch.device | str) -> AssimilationModel: + """Move model to device. + + Parameters + ---------- + device : torch.device | str + Target device + + Returns + ------- + AssimilationModel + Model on target device + """ + super().to(device) + return self +``` + +### **[CONFIRM — Model Loading]** + +Present to the user: + +1. The checkpoint URL/path for `load_default_package` +2. The checkpoint file names and loading logic +3. Whether there are multiple checkpoint files +4. The `.to()` implementation + +--- + +## Step 7 — Register the Model + +### 7a. Add to `__init__.py` + +Edit `earth2studio/models/da/__init__.py`: + +- Add import in alphabetical order: + `from earth2studio.models.da.<filename> import <ClassName>` +- If an `__all__` list exists, add `<ClassName>` to it in alphabetical order + +### 7b. Verify pyproject.toml + +Confirm the dependency group was added in Step 2 and is included in the `all` aggregate. + +--- + +## Step 8 — Verify Style, Documentation, Format & Lint + +Before testing, verify the wrapper passes all code quality checks. + +### 8a. Run formatting + +```bash +make format +``` + +This runs `black` on the codebase. Fix any formatting issues in +the new wrapper file. + +### 8b. Run linting + +```bash +make lint +``` + +This runs `ruff` and `mypy`. Common issues to watch for: + +- Missing type annotations on public functions +- Unused imports +- Import ordering issues +- Type errors from incorrect return types or missing annotations + +Fix all errors before proceeding. + +### 8c. Check license headers + +```bash +make license +``` + +Verify that the wrapper file +(`earth2studio/models/da/<filename>.py`) has the correct SPDX +Apache-2.0 license header (2024-2026 copyright years). + +### 8d. Verify documentation + +Check that: + +- The class docstring follows NumPy-style formatting with + `Parameters`, `Note`, `Badges` sections +- All public methods have complete docstrings with + `Parameters`, `Returns`, `Raises` sections as applicable +- Type hints are present on all public method signatures +- The model is added to `docs/modules/models.rst` in the + `earth2studio.models.da` section (alphabetical order) + +If any checks fail, fix the issues and re-run until all pass cleanly. + +--- + +## Reminders + +- **DO** return tuples from `input_coords` and `output_coords`, + even for single inputs/outputs +- **DO** use `FrameSchema` for tabular DataFrame inputs and + `CoordSystem` for gridded outputs +- **DO** validate `request_time` from `obs.attrs` — it is set + by earth2studio data sources +- **DO** use `validate_observation_fields()` to check required + DataFrame columns early +- **DO** use `filter_time_range()` for time-window filtering + of observations +- **DO** use `dfseries_to_torch()` for zero-copy cudf to torch + column conversion +- **DO** prime `create_generator` with `yield None` (stateless) + or `yield initial_state` (stateful) before the loop +- **DO** handle `GeneratorExit` in `create_generator` for + clean-up +- **DO** register `device_buffer` and expose a `device` property +- **DO** return cupy arrays on GPU, numpy arrays on CPU for + `xr.DataArray` output +- **DO** use `loguru.logger` for logging, never `print()`, + inside `earth2studio/` +- **DO** ensure all public functions have full type hints +- **DO** run formatting (`make format`) and linting + (`make lint`) before finalizing +- **DO** use `@torch.inference_mode()` on `__call__` for + inference-only models +- **DO** set `eval()` on loaded NN models in `load_model` +- **DO** add the model to `docs/modules/models.rst` in the + `earth2studio.models.da` section (alphabetical order) +- **DO** use `uv run python` for all Python commands (never + bare `python`) +- **DO NOT** use `@batch_func()` or `@batch_coords()` — these + are px/dx conventions only +- **DO NOT** use `PrognosticMixin` — DA models do not need + iterator hooks +- **DO NOT** use `create_iterator` — DA uses + `create_generator` with send protocol +- **DO NOT** assume tensor inputs — inputs are + DataFrames/DataArrays +- **DO NOT** forget cudf/cupy optional import pattern +- **DO NOT** use `@torch.inference_mode()` if the forward pass + requires gradients (e.g., DPS guidance); document the reason + if omitted +- **DO NOT** attempt to make a general base class with intent + to reuse the wrapper across models +- **DO NOT** over-populate the `load_model()` API — only + expose essential parameters diff --git a/.claude/skills/validate-assimilation-wrapper/SKILL.md b/.claude/skills/validate-assimilation-wrapper/SKILL.md new file mode 100644 index 000000000..45f785ced --- /dev/null +++ b/.claude/skills/validate-assimilation-wrapper/SKILL.md @@ -0,0 +1,1553 @@ +--- +name: validate-assimilation-wrapper +description: >- + Validate a newly created Earth2Studio data assimilation model wrapper by + writing unit tests (90% coverage required), performing reference comparison + with DataFrame/DataArray outputs, generating sanity-check plots, and + opening a PR with automated code review. Use after completing + create-assimilation-wrapper Steps 0-8. +argument-hint: Name of the DA model class and test file (optional — will be inferred from recent changes if not provided) +--- + +# Validate Assimilation Model Wrapper + +Validate a newly created Earth2Studio data assimilation (DA) model +wrapper by writing unit tests, performing reference comparison, +generating sanity-check outputs, and opening a PR with automated code +review. This skill picks up after the `create-assimilation-wrapper` +skill completes implementation (Steps 0-8). + +> **Python Environment:** This project uses **uv** for dependency +> management. Always use the local `.venv` virtual environment +> (`source .venv/bin/activate` or prefix with `uv run python`) for all +> Python commands — installing packages, running tests, executing +> scripts, etc. Use `uv add` / `uv pip install` / `uv lock` instead of +> `pip install`. + +Each confirmation gate marked by: + +```markdown +### **[CONFIRM — <Title>]** +``` + +requires **explicit user approval** before proceeding. + +--- + +## Step 1 — Write Pytest Unit Tests + +Create a test file at `test/models/da/test_<filename>.py`. DA tests +are fundamentally different from px/dx tests — inputs are DataFrames +(not tensors), outputs are `xr.DataArray` (not tensor + CoordSystem +tuples), and the generator uses the send protocol (not an iterator). + +### 1a. PhooModelName dummy class + +Create a lightweight dummy model that mimics the DA model under test. +The dummy accepts `pd.DataFrame` input and returns `xr.DataArray` +output. This is used for unit tests that do not require the real +model checkpoint. + +```python +class PhooModelName(torch.nn.Module): + """Dummy DA model for testing.""" + + VARIABLES = ["t2m", "u10m"] + + def __init__(self) -> None: + super().__init__() + self.register_buffer("device_buffer", torch.empty(0)) + self._lat = np.linspace(25.0, 50.0, 11, dtype=np.float32) + self._lon = np.linspace(235.0, 295.0, 13, dtype=np.float32) + + @property + def device(self) -> torch.device: + return self.device_buffer.device + + def __call__(self, obs: pd.DataFrame) -> xr.DataArray: + request_time = obs.attrs["request_time"] + data = torch.randn( + len(request_time), + len(self.VARIABLES), + len(self._lat), + len(self._lon), + ) + # Return as xr.DataArray (numpy for CPU) + return xr.DataArray( + data=data.numpy(), + dims=["time", "variable", "lat", "lon"], + coords={ + "time": request_time, + "variable": np.array(self.VARIABLES, dtype=str), + "lat": self._lat, + "lon": self._lon, + }, + ) + + def create_generator(self): + observations = yield None # Prime + try: + while True: + result = self.__call__(observations) + observations = yield result + except GeneratorExit: + pass + + def to(self, device): + return super().to(device) +``` + +### 1b. Test fixtures + +Define these fixtures at the top of the test file: + +```python +import numpy as np +import pandas as pd +import pytest +import torch +import xarray as xr + +try: + import cudf +except ImportError: + cudf = None + +try: + import cupy as cp +except ImportError: + cp = None + +from earth2studio.models.da.<module> import ModelName + + +@pytest.fixture +def sample_observations_pandas(): + """Create sample pandas DataFrame observations for testing.""" + time1 = np.datetime64("2024-01-01T12:00:00") + time2 = np.datetime64("2024-01-01T13:00:00") + return pd.DataFrame( + { + "time": [ + time1, time1, time1, time1, + time2, time2, time2, time2, + ], + "lat": [30.0, 30.0, 40.0, 40.0, 30.0, 30.0, 40.0, 40.0], + "lon": [240.0, 250.0, 240.0, 250.0, 240.0, 250.0, 240.0, 250.0], + "observation": [10.0, 20.0, 25.0, 35.0, 30.0, 40.0, 35.0, 45.0], + "variable": [ + "t2m", "t2m", "u10m", "u10m", + "t2m", "t2m", "u10m", "u10m", + ], + } + ) + + +@pytest.fixture +def sample_observations_cudf(): + """Create sample cudf DataFrame observations for testing.""" + if cudf is None: + pytest.skip("cudf not available") + time1 = np.datetime64("2024-01-01T12:00:00") + time2 = np.datetime64("2024-01-01T13:00:00") + return cudf.DataFrame( + { + "time": [ + time1, time1, time1, time1, + time2, time2, time2, time2, + ], + "lat": [30.0, 30.0, 40.0, 40.0, 30.0, 30.0, 40.0, 40.0], + "lon": [240.0, 250.0, 240.0, 250.0, 240.0, 250.0, 240.0, 250.0], + "observation": [10.0, 20.0, 30.0, 40.0, 30.0, 40.0, 50.0, 60.0], + "variable": [ + "t2m", "t2m", "u10m", "u10m", + "t2m", "t2m", "u10m", "u10m", + ], + } + ) + + +@pytest.fixture +def test_package(tmp_path): + """Create a dummy checkpoint directory for integration tests.""" + # Adapt contents to match the model's load_model expectations + checkpoint_dir = tmp_path / "checkpoint" + checkpoint_dir.mkdir() + # Example: create a dummy weights file + torch.save({"state_dict": {}}, checkpoint_dir / "model.pt") + return checkpoint_dir +``` + +### 1c. Test methods + +Write the following test methods. Each test must be **complete and +runnable** — no placeholder stubs. + +> **Note:** Every test file must start with the SPDX Apache-2.0 +> license header (see the create skill for the exact template). + +#### test_model_init + +```python +def test_model_init(): + """Test constructor parameter validation.""" + model = ModelName(...) # Fill in required constructor args + + # Verify model attributes are set correctly + assert hasattr(model, "device_buffer") + assert hasattr(model, "device") + + # Test invalid constructor args raise errors + with pytest.raises((ValueError, TypeError)): + ModelName(invalid_param=...) +``` + +#### test_model_call + +```python +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda:0", + marks=pytest.mark.skipif( + not torch.cuda.is_available(), reason="cuda missing" + ), + ), + ], +) +def test_model_call(sample_observations_pandas, device): + """Test stateless __call__ with pandas DataFrame input.""" + model = ModelName(...).to(device) + + request_time = np.array([np.datetime64("2024-01-01T12:00:00")]) + sample_observations_pandas.attrs = { + "request_time": request_time, + "request_lead_time": np.array([np.timedelta64(0, "h")]), + } + + da = model(sample_observations_pandas) + + assert isinstance(da, xr.DataArray) + assert da.dims == ("time", "variable", "lat", "lon") + assert da.coords["time"].values[0] == request_time[0] + + # Validate output shape matches model's coordinate system + n_variables = len(model.VARIABLES) if hasattr(model, "VARIABLES") else da.shape[1] + assert da.shape[0] == len(request_time) + assert da.shape[1] == n_variables + + # Validate coordinate values + assert "t2m" in da.coords["variable"].values # At least one expected variable + + # Check device-specific return type + if device == "cuda:0" and torch.cuda.is_available(): + if cp is not None: + assert isinstance(da.data, cp.ndarray) + assert not cp.all(cp.isnan(da.data)) + else: + assert isinstance(da.data, np.ndarray) + assert not np.all(np.isnan(da.values)) +``` + +#### test_model_call_cudf + +```python +@pytest.mark.parametrize( + "device", + [ + pytest.param( + "cuda:0", + marks=pytest.mark.skipif( + not torch.cuda.is_available(), reason="cuda missing" + ), + ), + ], +) +def test_model_call_cudf(sample_observations_cudf, device): + """Test __call__ with cudf DataFrame input on GPU.""" + if cudf is None: + pytest.skip("cudf not available") + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + model = ModelName(...).to(device) + + request_time = np.array([np.datetime64("2024-01-01T12:00:00")]) + sample_observations_cudf.attrs = { + "request_time": request_time, + "request_lead_time": np.array([np.timedelta64(0, "h")]), + } + + da = model(sample_observations_cudf) + + assert isinstance(da, xr.DataArray) + assert da.dims == ("time", "variable", "lat", "lon") + assert da.coords["time"].values[0] == request_time[0] + + if cp is not None: + assert isinstance(da.data, cp.ndarray) + assert not cp.all(cp.isnan(da.data)) + else: + assert isinstance(da.data, np.ndarray) + assert not np.all(np.isnan(da.values)) +``` + +#### test_generator_protocol + +```python +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda:0", + marks=pytest.mark.skipif( + not torch.cuda.is_available(), reason="cuda missing" + ), + ), + ], +) +def test_generator_protocol(sample_observations_pandas, device): + """Test create_generator prime -> send -> close sequence.""" + model = ModelName(...).to(device) + + request_time = np.array([np.datetime64("2024-01-01T12:00:00")]) + sample_observations_pandas.attrs = { + "request_time": request_time, + "request_lead_time": np.array([np.timedelta64(0, "h")]), + } + + generator = model.create_generator() + + # Prime the generator + result = generator.send(None) + assert result is None + + # Send observations — first step + da = generator.send(sample_observations_pandas) + assert isinstance(da, xr.DataArray) + assert da.dims == ("time", "variable", "lat", "lon") + assert da.shape[0] == len(request_time) + + # Send observations — second step + da2 = generator.send(sample_observations_pandas) + assert isinstance(da2, xr.DataArray) + assert da2.shape == da.shape + + # Close generator + generator.close() +``` + +#### test_init_coords + +```python +def test_init_coords(): + """Test init_coords returns correct type for the model.""" + model = ModelName(...) + + result = model.init_coords() + + # Choose ONE of the following patterns based on the model: + + # For stateless models (no initialization data required): + assert result is None + + # For stateful models (requires initialization data): + # assert isinstance(result, tuple) + # assert len(result) > 0 + # for schema in result: + # assert isinstance(schema, (dict, OrderedDict)) +``` + +#### test_input_coords + +```python +def test_input_coords(): + """Test input_coords returns tuple of FrameSchema.""" + model = ModelName(...) + + result = model.input_coords() + + assert isinstance(result, tuple) + assert len(result) >= 1 + + # Each element should be a FrameSchema (OrderedDict) + for schema in result: + assert isinstance(schema, dict) + # DA observation schemas typically have these columns + assert "variable" in schema +``` + +#### test_output_coords + +```python +def test_output_coords(): + """Test output_coords returns valid tuple of CoordSystem.""" + model = ModelName(...) + + input_coords = model.input_coords() + request_time = np.array([np.datetime64("2024-01-01T12:00:00")]) + + result = model.output_coords(input_coords, request_time=request_time) + + assert isinstance(result, tuple) + assert len(result) >= 1 + + # Each output should be a CoordSystem with expected dimensions + for coords in result: + assert isinstance(coords, dict) + assert "time" in coords + assert "variable" in coords +``` + +#### test_time_tolerance + +```python +def test_time_tolerance(): + """Test filter_time_range behavior with time tolerance.""" + model = ModelName(...) + + base_time = np.datetime64("2024-01-01T12:00:00") + time_within = base_time - np.timedelta64(30, "m") + time_outside = base_time + np.timedelta64(24, "h") + + obs_df = pd.DataFrame( + { + "time": [time_within, time_outside], + "lat": [30.0, 40.0], + "lon": [240.0, 250.0], + "observation": [10.0, 20.0], + "variable": ["t2m", "t2m"], + } + ) + + request_time = np.array([base_time]) + obs_df.attrs = { + "request_time": request_time, + "request_lead_time": np.array([np.timedelta64(0, "h")]), + } + + da = model(obs_df) + assert isinstance(da, xr.DataArray) + # Verify output is valid — exact assertions depend on model behavior + assert da.shape[0] == len(request_time) +``` + +#### test_empty_dataframe + +```python +def test_empty_dataframe(): + """Test graceful handling of empty DataFrame.""" + model = ModelName(...) + + empty_df = pd.DataFrame( + { + "time": pd.Series([], dtype="datetime64[ns]"), + "lat": pd.Series([], dtype=np.float32), + "lon": pd.Series([], dtype=np.float32), + "observation": pd.Series([], dtype=np.float32), + "variable": pd.Series([], dtype=str), + } + ) + + request_time = np.array([np.datetime64("2024-01-01T12:00:00")]) + empty_df.attrs = { + "request_time": request_time, + "request_lead_time": np.array([np.timedelta64(0, "h")]), + } + + # Model should either return a DataArray (possibly with NaN) or raise cleanly + try: + da = model(empty_df) + assert isinstance(da, xr.DataArray) + except (ValueError, RuntimeError): + pass # Acceptable to raise on empty input +``` + +#### test_invalid_attrs + +```python +def test_invalid_attrs(): + """Test that missing request_time in attrs raises an error.""" + model = ModelName(...) + + obs_df = pd.DataFrame( + { + "time": [np.datetime64("2024-01-01T12:00:00")], + "lat": [30.0], + "lon": [240.0], + "observation": [10.0], + "variable": ["t2m"], + } + ) + # Intentionally do NOT set obs_df.attrs with request_time + + with pytest.raises((ValueError, KeyError, TypeError)): + model(obs_df) +``` + +#### test_validate_observation_fields + +```python +def test_validate_observation_fields(): + """Test that invalid DataFrame columns raise an error.""" + model = ModelName(...) + + bad_df = pd.DataFrame( + { + "wrong_column": [1.0], + "another_bad": [2.0], + } + ) + bad_df.attrs = { + "request_time": np.array([np.datetime64("2024-01-01T12:00:00")]), + } + + with pytest.raises((ValueError, KeyError)): + model(bad_df) +``` + +#### test_model_exceptions + +```python +def test_model_exceptions(): + """Test model raises on invalid inputs.""" + model = ModelName(...) + + # Test with None input (if model requires non-None) + with pytest.raises((ValueError, TypeError)): + model(None) +``` + +#### Integration test (@pytest.mark.package) + +```python +@pytest.mark.package +def test_model_package(test_package): + """Integration test using model checkpoint. + + This test requires the actual model package and is skipped by + default. Run with --slow to include. + """ + from earth2studio.models.auto import Package + + package = Package(str(test_package)) + model = ModelName.load_model(package) + + obs_df = pd.DataFrame( + { + "time": [np.datetime64("2024-01-01T12:00:00")], + "lat": [30.0], + "lon": [240.0], + "observation": [10.0], + "variable": ["t2m"], + } + ) + obs_df.attrs = { + "request_time": np.array([np.datetime64("2024-01-01T12:00:00")]), + "request_lead_time": np.array([np.timedelta64(0, "h")]), + } + + da = model(obs_df) + assert isinstance(da, xr.DataArray) +``` + +### **[CONFIRM — Package Test]** + +Before writing the `@pytest.mark.package` integration test, ask the +user to confirm the model loading path and checkpoint structure: + +> The integration test needs to load the actual model checkpoint. +> +> 1. What is the checkpoint path or package URL? +> 2. Does `load_model` require additional arguments beyond `package`? +> 3. What test inputs should the integration test use? +> +> I'll write the `@pytest.mark.package` test based on your answers. + +**Do not proceed to Step 2 until the user confirms.** + +--- + +## Step 2 — Run Tests & Achieve 90% Coverage + +### 2a. Run the new test file + +```bash +uv run python -m pytest test/models/da/test_<filename>.py -v --timeout=60 +``` + +All tests must pass. Fix failures and re-run until green. + +### 2b. Run coverage report with `--slow` tests + +Run the new test file **with coverage** and the `--slow` flag to +include integration tests. The new DA model file must achieve **at +least 90% line coverage**: + +```bash +uv run python -m pytest test/models/da/test_<filename>.py -v \ + --slow --timeout=300 \ + --cov=earth2studio/models/da/<filename> \ + --cov-report=term-missing \ + --cov-fail-under=90 +``` + +- `--slow` enables integration tests marked with `@pytest.mark.package` + (the `--slow` flag is configured in `conftest.py` to include package + tests that download real checkpoints and may require GPU) +- `--cov=earth2studio/models/da/<filename>` scopes coverage to the + new model module only +- `--cov-report=term-missing` shows which lines are not covered +- `--cov-fail-under=90` fails the run if coverage is below 90% + +If coverage is below 90%, add additional tests or mock tests to cover +the missing lines. Common DA-specific coverage gaps: + +- `GeneratorExit` cleanup path in `create_generator` +- cudf code paths (skipped when cudf is unavailable) +- Empty DataFrame handling branches +- Time tolerance edge cases (observations right at boundary) +- `obs.attrs` validation branches (missing `request_time`) +- cupy vs numpy output paths (GPU vs CPU return types) +- `filter_time_range` with no matching observations +- `dfseries_to_torch` conversion branches + +Re-run until coverage is at or above 90%. + +### 2c. Run the full model test suite (optional but recommended) + +```bash +make pytest TOX_ENV=test-models +``` + +Confirm no regressions in existing model tests. + +--- + +## Step 3 — Reference Comparison & Sanity-Check + +This step validates the DA model wrapper produces correct output by +comparing against the original reference implementation and generating +visual sanity-check plots. + +### 3a. Create reference comparison script + +Create a **standalone Python script** in the repo root. This is for +validation only and should **NOT** be committed to the repo. + +The script loads the reference model and the E2S wrapper side by side, +runs both on identical input, and compares outputs with tolerance. + +**For DataArray output** (gridded analysis fields): + +```python +"""Reference comparison for <ModelName> assimilation model. + +Compares the Earth2Studio wrapper output against the original reference +implementation to verify numerical agreement. + +This script is for validation only — do NOT commit to the repo. +""" +import numpy as np +import pandas as pd +import torch + +# --- Reference model --- +# TODO: Load original model per reference repo instructions +# Uncomment and adapt the following lines: +# ref_model = ... +# ref_obs = pd.DataFrame({...}) +# ref_obs.attrs = {"request_time": ..., "request_lead_time": ...} +# ref_output = ref_model(ref_obs) +raise NotImplementedError( + "Fill in the reference model code above, then remove this line." +) + +# --- Earth2Studio wrapper --- +from earth2studio.models.da import ModelName + +model = ModelName(...) # or ModelName.load_model(package) +device = "cuda" if torch.cuda.is_available() else "cpu" +model = model.to(device) + +# Construct identical observation DataFrame +obs_df = pd.DataFrame( + { + "time": [...], + "lat": [...], + "lon": [...], + "observation": [...], + "variable": [...], + } +) +obs_df.attrs = { + "request_time": np.array([...], dtype="datetime64[ns]"), + "request_lead_time": np.array([np.timedelta64(0, "h")]), +} + +e2s_output = model(obs_df) + +# --- Compare DataArray outputs --- +ref_values = ref_output.values # numpy array +e2s_values = e2s_output.values # numpy or cupy array +if hasattr(e2s_values, "get"): + e2s_values = e2s_values.get() # cupy -> numpy + +max_abs_diff = np.abs(ref_values - e2s_values).max() +max_rel_diff = ( + np.abs(ref_values - e2s_values) / (np.abs(ref_values) + 1e-8) +).max() +correlation = np.corrcoef(ref_values.flatten(), e2s_values.flatten())[0, 1] + +print(f"Max absolute difference: {max_abs_diff:.2e}") +print(f"Max relative difference: {max_rel_diff:.2e}") +print(f"Correlation: {correlation:.8f}") + +assert np.allclose(ref_values, e2s_values, rtol=1e-4, atol=1e-5), \ + f"Output mismatch! Max abs diff: {max_abs_diff:.2e}" + +# --- Compare generator outputs --- +ref_gen = ref_model.create_generator() +ref_gen.send(None) +e2s_gen = model.create_generator() +e2s_gen.send(None) + +for step_obs in [obs_df, obs_df]: + ref_step = ref_gen.send(step_obs) + e2s_step = e2s_gen.send(step_obs) + ref_step_vals = ref_step.values + e2s_step_vals = e2s_step.values + if hasattr(e2s_step_vals, "get"): + e2s_step_vals = e2s_step_vals.get() + step_diff = np.abs(ref_step_vals - e2s_step_vals).max() + print(f"Generator step max abs diff: {step_diff:.2e}") + assert np.allclose(ref_step_vals, e2s_step_vals, rtol=1e-4, atol=1e-5) + +ref_gen.close() +e2s_gen.close() + +print("PASS: Reference comparison successful.") +``` + +**For DataFrame output** (if the model returns tabular data): + +```python +# --- Compare DataFrame outputs --- +assert len(ref_output) == len(e2s_output), \ + f"Row count mismatch: ref={len(ref_output)}, e2s={len(e2s_output)}" + +for col in ["lat", "lon", "observation", "variable"]: + if col in ref_output.columns: + ref_vals = ref_output[col].values + e2s_vals = e2s_output[col].values + if np.issubdtype(ref_vals.dtype, np.floating): + max_diff = np.abs(ref_vals - e2s_vals).max() + print(f"Column '{col}' max diff: {max_diff:.2e}") + else: + assert np.array_equal(ref_vals, e2s_vals), \ + f"Column '{col}' values differ" + +# Spatial coverage check +ref_lat_range = (ref_output["lat"].min(), ref_output["lat"].max()) +e2s_lat_range = (e2s_output["lat"].min(), e2s_output["lat"].max()) +print(f"Lat range: ref={ref_lat_range}, e2s={e2s_lat_range}") + +ref_lon_range = (ref_output["lon"].min(), ref_output["lon"].max()) +e2s_lon_range = (e2s_output["lon"].min(), e2s_output["lon"].max()) +print(f"Lon range: ref={ref_lon_range}, e2s={e2s_lon_range}") +``` + +### 3b. Summarize model capabilities to user + +Before generating sanity-check plots, **present a summary table** to +the user covering the model's capabilities: + +> **Model Summary for `<ClassName>`:** +> +> | Property | Value | +> | ---------------------- | ------------------------------------ | +> | **Model type** | Stateless / Stateful | +> | **Input format** | DataFrame / DataArray / Mixed | +> | **Input schema** | time, lat, lon, observation, variable | +> | **Output format** | DataArray / DataFrame | +> | **Output grid** | lat-lon N x M / HRRR / HealPix | +> | **Output variables** | `var1`, `var2`, ... | +> | **Time tolerance** | (default value) | +> | **cudf/cupy support** | Yes / No | +> | **Observation types** | Surface / Satellite / Radar / Mixed | +> | **Checkpoint source** | NGC / HuggingFace / N/A | + +This summary helps the user verify the wrapper matches their +expectations for the model. + +### 3c. Generate sanity-check plot scripts + +Create **standalone Python scripts** in the repo root. These are for +PR reviewer reference only and should **NOT** be committed to the +repo. + +#### Plot 1: Spatial assimilated output + +Contourf of gridded DataArray output from `__call__`: + +```python +"""Sanity-check plot 1: Spatial assimilated output for <ModelName>. + +This script is for PR review only — do NOT commit to the repo. +""" +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch + +from earth2studio.models.da import ModelName + +# Load model +model = ModelName(...) # or ModelName.load_model(package) +device = "cuda" if torch.cuda.is_available() else "cpu" +model = model.to(device) + +# Create observation DataFrame +obs_df = pd.DataFrame( + { + "time": [np.datetime64("2024-01-01T12:00:00")] * 4, + "lat": [30.0, 35.0, 40.0, 45.0], + "lon": [240.0, 250.0, 260.0, 270.0], + "observation": [10.0, 15.0, 20.0, 25.0], + "variable": ["t2m", "t2m", "t2m", "t2m"], + } +) +obs_df.attrs = { + "request_time": np.array([np.datetime64("2024-01-01T12:00:00")]), + "request_lead_time": np.array([np.timedelta64(0, "h")]), +} + +# Run forward pass +da = model(obs_df) + +# Get coordinate arrays +lat = da.coords["lat"].values +lon = da.coords["lon"].values +variables = da.coords["variable"].values + +# Plot contourf for each output variable +n_vars = len(variables) +fig, axes = plt.subplots(1, n_vars, figsize=(6 * n_vars, 5)) +if n_vars == 1: + axes = [axes] + +for ax, var in zip(axes, variables): + data_2d = da.sel(variable=var).isel(time=0).values + if hasattr(data_2d, "get"): + data_2d = data_2d.get() # cupy -> numpy + im = ax.contourf(lon, lat, data_2d, cmap="turbo", levels=20) + ax.set_title(f"{var}") + ax.set_xlabel("Longitude") + ax.set_ylabel("Latitude") + plt.colorbar(im, ax=ax, shrink=0.8) + +plt.suptitle(f"<ModelName> assimilated output", y=1.02) +plt.tight_layout() +plt.savefig("sanity_check_da_spatial.png", dpi=150, bbox_inches="tight") +print("Saved: sanity_check_da_spatial.png") +``` + +#### Plot 2: Observation overlay (unique to DA) + +Scatter of input DataFrame observations overlaid on assimilated grid +output. This visualization is **specific to DA models** — it shows the +sparse-to-dense mapping from observations to analysis field. + +```python +"""Sanity-check plot 2: Observation overlay for <ModelName>. + +This script is for PR review only — do NOT commit to the repo. +Shows input sparse observations overlaid on the assimilated gridded output. +""" +import cartopy.crs as ccrs +import cartopy.feature as cfeature +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch + +from earth2studio.models.da import ModelName + +# Load model +model = ModelName(...) # or ModelName.load_model(package) +device = "cuda" if torch.cuda.is_available() else "cpu" +model = model.to(device) + +# Create observation DataFrame +obs_df = pd.DataFrame( + { + "time": [np.datetime64("2024-01-01T12:00:00")] * 8, + "lat": [30.0, 32.0, 35.0, 38.0, 40.0, 42.0, 45.0, 48.0], + "lon": [240.0, 245.0, 250.0, 255.0, 260.0, 265.0, 270.0, 275.0], + "observation": [10.0, 12.0, 15.0, 18.0, 20.0, 22.0, 25.0, 28.0], + "variable": ["t2m"] * 8, + } +) +obs_df.attrs = { + "request_time": np.array([np.datetime64("2024-01-01T12:00:00")]), + "request_lead_time": np.array([np.timedelta64(0, "h")]), +} + +# Run forward pass +da = model(obs_df) + +# Get coordinate arrays +lat = da.coords["lat"].values +lon = da.coords["lon"].values +variables = [v for v in da.coords["variable"].values if v in obs_df["variable"].unique()] + +n_vars = len(variables) +fig, axes = plt.subplots( + 1, n_vars, figsize=(8 * n_vars, 6), + subplot_kw={"projection": ccrs.PlateCarree()}, +) +if n_vars == 1: + axes = [axes] + +for ax, var in zip(axes, variables): + # Plot assimilated grid as contourf + data_2d = da.sel(variable=var).isel(time=0).values + if hasattr(data_2d, "get"): + data_2d = data_2d.get() # cupy -> numpy + ax.contourf( + lon, lat, data_2d, + cmap="turbo", alpha=0.5, levels=20, + transform=ccrs.PlateCarree(), + ) + + # Overlay input observations as scatter + obs_var = obs_df[obs_df["variable"] == var] + scatter = ax.scatter( + obs_var["lon"].values, obs_var["lat"].values, + c=obs_var["observation"].values, + cmap="turbo", edgecolors="k", s=40, zorder=5, + transform=ccrs.PlateCarree(), + ) + plt.colorbar(scatter, ax=ax, shrink=0.7, label="Observation value") + + ax.add_feature(cfeature.COASTLINE) + ax.add_feature(cfeature.BORDERS, linestyle=":") + ax.set_title(f"{var}: grid + observations") + +plt.suptitle("<ModelName> — observation overlay", y=1.02) +plt.tight_layout() +plt.savefig("sanity_check_da_overlay.png", dpi=150, bbox_inches="tight") +print("Saved: sanity_check_da_overlay.png") +``` + +#### Plot 3: Generator sequence + +Multi-step assimilation evolution using the generator protocol: + +```python +"""Sanity-check plot 3: Generator sequence for <ModelName>. + +This script is for PR review only — do NOT commit to the repo. +Shows evolution of assimilated output across multiple generator steps. +""" +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch + +from earth2studio.models.da import ModelName + +# Load model +model = ModelName(...) # or ModelName.load_model(package) +device = "cuda" if torch.cuda.is_available() else "cpu" +model = model.to(device) + +# Create observation sequence — one DataFrame per step +base_time = np.datetime64("2024-01-01T12:00:00") +n_steps = 4 +observation_sequence = [] + +for i in range(n_steps): + step_time = base_time + np.timedelta64(i, "h") + obs_step = pd.DataFrame( + { + "time": [step_time] * 4, + "lat": [30.0, 35.0, 40.0, 45.0], + "lon": [240.0, 250.0, 260.0, 270.0], + "observation": [10.0 + i * 5, 15.0 + i * 5, 20.0 + i * 5, 25.0 + i * 5], + "variable": ["t2m"] * 4, + } + ) + obs_step.attrs = { + "request_time": np.array([step_time]), + "request_lead_time": np.array([np.timedelta64(0, "h")]), + } + observation_sequence.append(obs_step) + +# Run generator +gen = model.create_generator() +gen.send(None) # Prime +results = [] +for obs_step in observation_sequence: + result = gen.send(obs_step) + results.append(result) +gen.close() + +# Plot sequence +var = results[0].coords["variable"].values[0] # Plot first variable +lat = results[0].coords["lat"].values +lon = results[0].coords["lon"].values + +fig, axes = plt.subplots(1, len(results), figsize=(5 * len(results), 4)) +if len(results) == 1: + axes = [axes] + +for ax, (i, result) in zip(axes, enumerate(results)): + data_2d = result.sel(variable=var).isel(time=0).values + if hasattr(data_2d, "get"): + data_2d = data_2d.get() # cupy -> numpy + im = ax.contourf(lon, lat, data_2d, cmap="turbo", levels=20) + ax.set_title(f"Step {i}") + ax.set_xlabel("Longitude") + ax.set_ylabel("Latitude") + plt.colorbar(im, ax=ax, shrink=0.8) + +plt.suptitle(f"<ModelName> — generator sequence ({var})", y=1.02) +plt.tight_layout() +plt.savefig("sanity_check_da_generator.png", dpi=150, bbox_inches="tight") +print("Saved: sanity_check_da_generator.png") +``` + +### 3d. Create side-by-side comparison script + +Create a script that runs both the reference implementation and the +E2S wrapper with identical inputs and produces a side-by-side plot: + +```python +"""Side-by-side comparison: reference vs Earth2Studio for <ModelName>. + +This script is for validation only — do NOT commit to the repo. +Fill in the TODO sections below before running. +""" +import matplotlib.pyplot as plt +import numpy as np + +# TODO: Import and initialize both reference and E2S models +# TODO: Prepare identical input observations (DataFrame) +raise NotImplementedError( + "Fill in the reference and E2S model code above before running." +) + +# ref_output = ... # Run reference model +# e2s_output = ... # Run E2S wrapper + +fig, axes = plt.subplots(1, 3, figsize=(18, 5)) + +# Panel 1: Reference output +# axes[0].contourf(lon, lat, ref_data, cmap="turbo") +# axes[0].set_title("Reference") + +# Panel 2: E2S output +# axes[1].contourf(lon, lat, e2s_data, cmap="turbo") +# axes[1].set_title("Earth2Studio") + +# Panel 3: Difference +# axes[2].contourf(lon, lat, ref_data - e2s_data, cmap="RdBu_r") +# axes[2].set_title("Difference (Ref - E2S)") + +plt.suptitle("<ModelName> — reference vs Earth2Studio") +plt.tight_layout() +plt.savefig("comparison_<model_name>.png", dpi=150, bbox_inches="tight") +print("Saved: comparison_<model_name>.png") +``` + +### 3e. Run comparison and sanity-check scripts + +Execute the scripts: + +```bash +uv run python reference_comparison_<model_name>.py +uv run python sanity_check_da_spatial.py +uv run python sanity_check_da_overlay.py +uv run python sanity_check_da_generator.py +``` + +Verify that: + +- The reference comparison passes (all assertions hold) +- All sanity-check scripts run without errors +- Output PNGs are generated +- Metrics are printed (max abs diff, max rel diff, correlation) + +### 3f. **[CONFIRM — Sanity-Check & Comparison]** + +**You MUST ask the user to visually inspect the generated plot(s) +before proceeding.** Do not skip this step even if the scripts ran +without errors — a successful run does not guarantee the plots are +correct (e.g., empty axes, wrong colorbar range, garbled data). + +Tell the user the absolute path to the generated image file(s) and +the reference comparison metrics, then ask them to inspect: + +> The reference comparison and sanity-check scripts ran successfully. +> +> **Reference comparison metrics:** +> +> - Max absolute difference: `<value>` +> - Max relative difference: `<value>` +> - Correlation: `<value>` +> +> **Sanity-check plots saved to:** +> +> 1. `/absolute/path/to/sanity_check_da_spatial.png` — gridded output +> 2. `/absolute/path/to/sanity_check_da_overlay.png` — observations on grid +> 3. `/absolute/path/to/sanity_check_da_generator.png` — generator sequence +> +> **Please open these images and confirm they look correct.** Check: +> +> 1. Data is visible on the axes (not blank/empty) +> 2. Values are in physically reasonable ranges +> 3. No obvious artifacts (all-NaN regions, garbled values) +> 4. Spatial patterns look plausible (geographic features visible) +> 5. Observation overlay: scatter points visible on top of grid +> 6. Generator sequence: evolution across steps is coherent +> +> Do the plots look correct and do the reference comparison metrics +> look acceptable? + +**Do not proceed to Step 4 until the user explicitly confirms.** If +the user reports problems, debug and fix the issue, re-run the +scripts, and ask the user to inspect again. + +--- + +## Step 4 — Branch, Commit & Open PR + +### **[CONFIRM — Ready to Submit]** + +Before proceeding, confirm with the user: + +> All implementation and validation steps are complete: +> +> - DA model class implemented with correct method ordering +> - `init_coords` returns correct type (None for stateless or tuple for stateful) +> - `input_coords` returns tuple of FrameSchema +> - `output_coords` returns tuple with `request_time` +> - `create_generator` uses send protocol with GeneratorExit handling +> - `validate_observation_fields` used for DataFrame inputs +> - cupy/cudf optional import pattern present +> - Registered in `earth2studio/models/da/__init__.py` +> - Documentation added to `docs/modules/models.rst` +> - Reference URLs included in class docstrings +> - CHANGELOG.md updated +> - Format, lint, and license checks pass +> - Unit tests written and passing with >= 90% coverage +> - Dependencies in pyproject.toml confirmed +> - Reference comparison passes with acceptable tolerance +> - Sanity-check plots generated and confirmed by user +> +> Ready to create a branch, commit, and prepare a PR? + +### 4a. Create branch and commit + +```bash +git checkout -b feat/da-model-<name> +git add earth2studio/models/da/<filename>.py \ + earth2studio/models/da/__init__.py \ + test/models/da/test_<filename>.py \ + pyproject.toml \ + CHANGELOG.md \ + docs/modules/models.rst +git commit -m "feat: add <ClassName> assimilation model + +Add <ClassName> data assimilation model for <brief description>. +Includes unit tests and documentation." +``` + +Do **NOT** add the sanity-check scripts, comparison scripts, or +their output images. + +### 4b. Identify the fork remote and push branch + +The working repository is typically a **fork** of +`NVIDIA/earth2studio`. Before pushing, confirm which git remote +points to the user's fork: + +```bash +git remote -v +``` + +Ask the user: + +> Which git remote is your fork of `NVIDIA/earth2studio`? +> (Usually `origin` — e.g., `git@github.com:<user>/earth2studio.git`) + +Then push the feature branch to the **fork** remote: + +```bash +git push -u <fork-remote> feat/da-model-<name> +``` + +### 4c. Open Pull Request (fork -> NVIDIA/earth2studio) + +> **Important:** PRs must be opened **from the fork** to the +> **upstream source repository** `NVIDIA/earth2studio`. The branch +> lives on the fork; the PR targets `main` on the upstream repo. + +Use `gh pr create` with explicit `--repo` and `--head` flags: + +```bash +gh pr create \ + --repo NVIDIA/earth2studio \ + --base main \ + --head <fork-owner>:feat/da-model-<name> \ + --title "feat: add <ClassName> assimilation model" \ + --body "..." +``` + +Where `<fork-owner>` is the GitHub username that owns the fork. + +The PR body should follow this DA-model-specific template: + +````markdown +## Description + +Add `<ClassName>` data assimilation model for <brief description>. + +Closes #<issue_number> (if applicable) + +### Model details + +| Property | Value | +|---|---| +| **Model type** | Stateless / Stateful | +| **Input format** | DataFrame / DataArray / Mixed | +| **Output format** | DataArray / DataFrame | +| **Observation schema** | time, lat, lon, observation, variable | +| **Grid specification** | lat-lon / HRRR / HealPix / etc. | +| **Time tolerance** | <default value> | +| **cudf/cupy support** | Yes / No | +| **Reference** | <link to paper/repo> | + +### Dependencies added + +| Package | Version | License | License URL | Reason | +|---|---|---|---|---| +| `<package>` | `>=X.Y` | <License> | [link](<URL>) | <reason> | + +*(or "No new dependencies")* + +### Reference comparison + +- Max absolute difference: <value> +- Max relative difference: <value> +- Correlation: <value> + +### Validation + +See sanity-check plots in PR comments below. + +## Checklist + +- [x] I am familiar with the [Contributing Guidelines][contrib]. +- [x] New or existing tests cover these changes. +- [x] The documentation is up to date with these changes. +- [x] The [CHANGELOG.md][changelog] is up to date with these changes. +- [ ] An [issue][issues] is linked to this pull request. +- [ ] Assess and address Greptile feedback (AI code review bot). + +[contrib]: https://github.com/NVIDIA/earth2studio/blob/main/CONTRIBUTING.md +[changelog]: https://github.com/NVIDIA/earth2studio/blob/main/CHANGELOG.md +[issues]: https://github.com/NVIDIA/earth2studio/issues +```` + +### 4d. Post sanity-check as PR comment + +After the PR is created, post the sanity-check visualization as a +separate **PR comment** so it is immediately visible to reviewers. + +#### Image upload limitation + +**GitHub has no CLI or REST API for uploading images to PR comments.** +The only way to embed an image is via the browser's drag-and-drop +editor or by referencing an already-hosted URL. + +**Practical workflow:** + +1. Write the comment body to a temp file (avoids shell quoting issues + with heredocs containing backticks and markdown). +2. Post the comment **without** the image — include the validation + table, reference comparison metrics, the full sanity-check script, + and a placeholder line. +3. Tell the user to drag the image into the browser editor. + +```bash +# 1. Write body to a temp file (use your editor tool, not heredoc) + +# 2. Post the comment +gh api -X POST repos/NVIDIA/earth2studio/issues/<PR_NUMBER>/comments \ + -F "body=@/tmp/pr_comment_body.md" \ + --jq '.html_url' +``` + +Do **not** waste time trying `curl` uploads, GraphQL file mutations, +or the `uploads.github.com` asset endpoint — they do not work for +issue/PR comment images. + +#### Comment content template + +```markdown +## Sanity-Check Validation + +**Model:** `<ClassName>` — <brief description> +**Type:** Stateless / Stateful +**Test environment:** <GPU model or CPU> + +### Reference Comparison + +| Metric | Value | +|--------|-------| +| Max absolute difference | <value> | +| Max relative difference | <value> | +| Correlation | <value> | + +### Model Summary + +| Property | Value | +|----------|-------| +| Model type | Stateless / Stateful | +| Input format | DataFrame / DataArray / Mixed | +| Output format | DataArray / DataFrame | +| Observation schema | time, lat, lon, observation, variable | +| Output grid | lat-lon N x M / HRRR / HealPix | +| Output variables | <list or count> | +| Time tolerance | <default value> | +| cudf/cupy support | Yes / No | +| Inference time | ~XX ms | + +**Key findings:** +- <bullet summarizing numerical agreement with reference> +- <bullet on output quality / physical reasonableness> +- <bullet on performance or notable behavior> + +> **TODO:** Attach sanity-check images by editing this comment in +> the browser. + +<details> +<summary>Sanity-check scripts (click to expand)</summary> + +```python +PASTE THE FULL WORKING SCRIPTS HERE — not truncated excerpts. +The scripts must be copy-pasteable and produce the plots end-to-end. +``` + +</details> +``` + +**Important:** Always paste the **complete, runnable** scripts — not +shortened versions. Reviewers should be able to reproduce the plots +by copying the scripts directly. + +#### Finalize + +After posting, inform the user of: + +1. The comment URL +2. The local paths to the image files for manual attachment +3. Instructions: *"Edit the comment in your browser and drag the + image files into the editor to embed them."* + +> **Note:** The sanity-check images and scripts are for PR review +> purposes only — they must NOT be committed to the repository. + +--- + +## Step 5 — Automated Code Review (Greptile) + +After the PR is created and pushed, an automated code review from +**greptile-apps** (Greptile) will be posted as PR review comments. +Wait for this review, then process the feedback. + +### 5a. Wait for Greptile review + +Poll for review comments from `greptile-apps[bot]` every 30 seconds +for up to **5 minutes**. Time out gracefully if no review arrives: + +```bash +# Poll loop — check every 30s, timeout after 5 minutes (10 attempts) +for i in $(seq 1 10); do + REVIEW_ID=$(gh api repos/NVIDIA/earth2studio/pulls/<PR_NUMBER>/reviews \ + --jq '.[] | select(.user.login == "greptile-apps[bot]") | .id' 2>/dev/null) + if [ -n "$REVIEW_ID" ]; then + echo "Greptile review found: $REVIEW_ID" + break + fi + echo "Attempt $i/10 — no review yet, waiting 30s..." + sleep 30 +done +``` + +If no review after 5 minutes, inform the user: + +> Greptile hasn't posted a review after 5 minutes. This can happen if +> the review bot is busy or the PR hasn't triggered it. You can: +> +> 1. Ask me to check again later +> 2. Skip this step and proceed without automated review +> 3. Manually request a review from Greptile on the PR page + +### 5b. Pull and parse review comments + +Once the review is posted, fetch all comments: + +```bash +# Get all review comments on the PR +gh api repos/NVIDIA/earth2studio/pulls/<PR_NUMBER>/comments \ + --jq '.[] | select(.user.login == "greptile-apps[bot]") | + {path: .path, line: .diff_hunk, body: .body}' +``` + +Also fetch the top-level review body: + +```bash +gh api repos/NVIDIA/earth2studio/pulls/<PR_NUMBER>/reviews \ + --jq '.[] | select(.user.login == "greptile-apps[bot]") | .body' +``` + +### 5c. Categorize and present to user + +Parse each comment and categorize it: + +| Category | Description | Default action | +| --------------------- | --------------------------------- | -------------- | +| **Bug / correctness** | Logic errors, wrong behavior | Fix | +| **Style / convention** | Naming, formatting, patterns | Fix if valid | +| **Performance** | Inefficiency, resource waste | Evaluate | +| **Documentation** | Missing/wrong docs, docstrings | Fix | +| **Suggestion** | Alternative approach, nice-to-have | User decides | +| **False positive** | Incorrect or irrelevant feedback | Dismiss | + +### **[CONFIRM — Review Triage]** + +Present each comment to the user in a summary table: + +```markdown +| # | File | Line | Category | Summary | Proposed Action | +|---|------|------|----------|---------|-----------------| +| 1 | <model>.py | 142 | Bug | Missing null check | Fix: add guard | +| 2 | <model>.py | 305 | Style | Use f-string | Fix: convert | +| 3 | <model>.py | 45 | Suggestion | Add type alias | Skip: not needed | +| ... | ... | ... | ... | ... | ... | +``` + +For each comment, briefly explain: + +- What Greptile flagged +- Whether you agree or disagree (with reasoning) +- Your proposed fix (or why to skip) + +Ask the user to confirm which comments to address. The user may: + +- Accept all proposed fixes +- Select specific fixes +- Override your recommendation on any comment +- Add their own fixes + +### 5d. Implement fixes + +For each accepted fix: + +1. Make the code change +2. Run `make format && make lint` after all fixes +3. Run the relevant tests: + + ```bash + uv run python -m pytest test/models/da/test_<filename>.py -v --timeout=60 + ``` + +4. Commit with a message like: + + ```bash + git commit -m "fix: address code review feedback (Greptile)" + ``` + +### 5e. Respond to review comments + +For each Greptile comment, post a reply on the PR: + +**For fixed comments:** + +```bash +gh api repos/NVIDIA/earth2studio/pulls/<PR_NUMBER>/comments/<COMMENT_ID>/replies \ + -f body="Fixed in <commit_sha>. <brief description of fix>" +``` + +**For dismissed comments (false positives / won't fix):** + +```bash +gh api repos/NVIDIA/earth2studio/pulls/<PR_NUMBER>/comments/<COMMENT_ID>/replies \ + -f body="Won't fix — <brief justification>" +``` + +### 5f. Push and resolve + +```bash +git push <fork-remote> feat/da-model-<name> +``` + +After pushing, resolve all addressed review threads if possible. + +Inform the user of the final state: + +- How many comments were fixed +- How many were dismissed (with reasons) +- Any remaining open threads + +--- + +## Reminders + +- **DO** use the repo's local `uv` `.venv` to run Python with + `uv run python` +- **DO NOT** commit sanity-check/comparison scripts or images to + the repo +- **DO** use `loguru.logger` for logging, never `print()`, inside + `earth2studio/` +- **DO** ensure all public functions have full type hints (mypy-clean) +- **DO** maintain alphabetical order in `__init__.py` exports, + RST file entries, and CHANGELOG entries +- **DO** return tuples from `input_coords` and `output_coords` +- **DO** use `FrameSchema` for tabular inputs, `CoordSystem` for + gridded outputs +- **DO** validate `request_time` from `obs.attrs` +- **DO** use `validate_observation_fields`, `filter_time_range`, + `dfseries_to_torch` from `earth2studio.models.da.utils` +- **DO** prime generator with `yield None` and handle `GeneratorExit` +- **DO** return cupy arrays on GPU, numpy on CPU +- **DO** register `device_buffer` and expose `device` property +- **DO** follow the canonical DA method ordering: + `__init__`, `device` property, `init_coords`, `input_coords`, + `output_coords`, `load_default_package`, `load_model`, `to`, + private methods, `__call__`, `create_generator` +- **DO** include reference URLs in class docstrings +- **DO** always update CHANGELOG.md under the current unreleased + version +- **DO** add the model to `docs/modules/models.rst` in the + `earth2studio.models.da` section (alphabetical order) +- **DO NOT** use `@batch_func` or `@batch_coords` — these are + px/dx conventions only and do not apply to DA models +- **DO NOT** use `PrognosticMixin` — DA models do not time-step +- **DO NOT** use `create_iterator` — DA uses `create_generator` + with the send protocol +- **DO NOT** assume tensor inputs — DA inputs are DataFrames and/or + DataArrays +- **DO NOT** make a general base class with intent to reuse the + wrapper across models +- **DO NOT** over-populate the `load_model()` API — only expose + essential parameters +- **NEVER** commit, hardcode, or include API keys, secrets, tokens, + or credentials in source code, sample scripts, commit messages, + PR descriptions, or any file tracked by git diff --git a/docs/superpowers/plans/2026-04-13-assimilation-model-skills.md b/docs/superpowers/plans/2026-04-13-assimilation-model-skills.md new file mode 100644 index 000000000..d7161d94e --- /dev/null +++ b/docs/superpowers/plans/2026-04-13-assimilation-model-skills.md @@ -0,0 +1,582 @@ +# Assimilation Model Skills Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use +> superpowers:subagent-driven-development (recommended) or +> superpowers:executing-plans to implement this plan task-by-task. +> Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Create two Claude skills (`create-assimilation-wrapper` +and `validate-assimilation-wrapper`) that guide agents through +building and validating Earth2Studio data assimilation model +wrappers. + +**Architecture:** Each skill is a single Markdown file in +`.claude/skills/<name>/SKILL.md` with YAML frontmatter, numbered +steps with `[CONFIRM]` gates, complete code blocks, and a +Reminders section. The create skill covers Steps 0-8 +(implementation only), the validate skill covers Steps 1-5 +(testing, comparison, PR, Greptile). + +**Tech Stack:** Markdown skill files, Python code blocks +referencing Earth2Studio DA patterns (pandas/cudf DataFrames, +xarray DataArrays, FrameSchema, CoordSystem, send-protocol +generators). + +--- + +## Task 1: Create `create-assimilation-wrapper` SKILL.md + +**Files:** + +- Create: `.claude/skills/create-assimilation-wrapper/SKILL.md` + +This task creates the full skill file for guiding creation of DA +model wrappers. The file should be modeled on the existing +`create-diagnostic-wrapper/SKILL.md` structure but with +DA-specific content throughout. + +**Key reference files to read before writing:** + +- `.cursor/rules/e2s-013-assimilation-models.mdc` (577 lines) + — authoritative DA implementation rules +- `earth2studio/models/da/base.py` (188 lines) + — AssimilationModel Protocol +- `earth2studio/models/da/interp.py` (593 lines) + — simple stateless DA example +- `earth2studio/models/da/utils.py` + — DA utilities (validate, filter, convert) +- `.claude/skills/create-diagnostic-wrapper/SKILL.md` + — structural template (layout, step structure, CONFIRM gates) +- `docs/superpowers/specs/2026-04-13-assimilation-model-skills-design.md` + — approved spec + +**Structure to follow:** + +The file must contain: + +1. **YAML Frontmatter** (lines 1-5): + + ```yaml + --- + name: create-assimilation-wrapper + description: >- + Create a new Earth2Studio data assimilation model (da) + wrapper from a reference inference script or repository. + DA models ingest sparse observations (DataFrames) and/or + gridded state arrays (DataArrays) and produce analysis + output — they do NOT use @batch_func, @batch_coords, + or PrognosticMixin. + argument-hint: >- + URL or local path to reference inference script/repo + (optional — will be asked if not provided) + --- + ``` + +2. **Title and preamble** — "Create Assimilation Model Wrapper", + CONFIRM gate explanation, uv/venv environment note. + +3. **Critical Differences from px/dx table** — The full table + from the spec (18 rows). This MUST be placed prominently + near the top, right after the preamble. It's the most + important reference for any agent using this skill. + +4. **Step 0 — Obtain Reference & Analyze Model** + - 0a: Accept `$ARGUMENTS` or ask user (same as dx skill) + - 0b: Analyze the reference to determine: + - Input types: DataFrame only, DataArray only, or mixed + - Output types: DataArray (gridded), DataFrame (tabular) + - Stateful vs stateless (state between steps?) + - Whether `@torch.inference_mode()` is safe + - Dependency packages + - No NN/physics branching — all DA models get the same flow + - Present analysis summary to user + - `[CONFIRM — Model Analysis]` + +5. **Step 1 — Examine Reference & Propose Dependencies** + - 1a: Identify packages (physicsnemo, scipy, healpy, etc.) + - 1b: Propose pyproject.toml group `da-<model-name>` + - Highlight cudf/cupy as optional GPU acceleration + - `[CONFIRM — Dependencies]` + +6. **Step 2 — Add Dependencies to pyproject.toml** + +7. **Step 3 — Create Skeleton Class File** + - File path: `earth2studio/models/da/<filename>.py` + - License header (SPDX Apache-2.0, 2024-2026) + - Dual inheritance: `torch.nn.Module + AutoModelMixin` + (NO PrognosticMixin) + - Class-level `@check_optional_dependencies()` decorator + - Optional dependency try/except with + `OptionalDependencyFailure("da-<name>")` + - cupy/cudf optional imports + + **Canonical DA method ordering** (MUST be enforced): + 1. `__init__` — register `device_buffer`, store model, + normalize time tolerance + 2. `device` property — `return self.device_buffer.device` + 3. `init_coords` — `None` for stateless, + `tuple[CoordSystem]` for stateful + 4. `input_coords` — `tuple[FrameSchema]` for DataFrame, + `tuple[CoordSystem]` for DataArray + 5. `output_coords` — accept `input_coords` tuple + + `request_time` kwarg, return tuple + 6. `load_default_package` — classmethod + 7. `load_model` — classmethod with + `@check_optional_dependencies()` + 8. `to` — device management, return `AssimilationModel` + 9. Private/support methods + 10. `__call__` — stateless forward + 11. `create_generator` — bidirectional generator with send + + **Complete skeleton code block** — must include all methods + with pseudocode `# TODO` comments. Use the complete example + from `e2s-013-assimilation-models.mdc` lines 420-544 as + the primary template. + + **Key differences from dx/px skeletons** to document: + + - No `@batch_func()`, no `@batch_coords()`, + no `PrognosticMixin` + - `input_coords` and `output_coords` return **tuples** + (even for single) + - `FrameSchema` for tabular inputs, + `CoordSystem` for gridded outputs + - `device` property instead of + `next(self.parameters()).device` + - `create_generator` with send protocol instead of + `create_iterator` + - `validate_observation_fields()` instead of + `handshake_coords`/`handshake_dim` for DataFrames + - `request_time` from `obs.attrs`, not a coordinate dim + + - `[CONFIRM — Skeleton]` + +8. **Step 4 — Implement Coordinate System** + - 4a: `init_coords()` — return `None` for stateless, + or tuple of `CoordSystem`/`FrameSchema` for stateful. + Show both patterns. + - 4b: `input_coords()` — return tuple of `FrameSchema`. + Show the standard 5-column schema: time, lat, lon, + observation, variable. Explain `np.empty(0, dtype=...)` + for unbounded vs `np.array([...])` for enumerated. + - 4c: `output_coords()` — accept tuple + `request_time`, + return tuple of `CoordSystem`. Show handshake helpers + only for `CoordSystem` inputs. Show + `validate_observation_fields()` for `FrameSchema`. + - `[CONFIRM — Coordinates]` + +9. **Step 5 — Implement Forward Pass** + - 5a: `__call__` — extract `request_time` from + `obs.attrs`, validate with + `validate_observation_fields()`, filter with + `filter_time_range()`, convert with + `dfseries_to_torch()`, run model, build + `xr.DataArray` with cupy/numpy based on device + - 5b: `create_generator` — two patterns: + - Stateless: `observations = yield None`, loop calling + `self.__call__(observations)`, handle `GeneratorExit` + - Stateful: accept init args, + `obs = yield initial_state`, loop with state updates, + handle `GeneratorExit` + - Code blocks for both patterns + - `[CONFIRM — Forward Pass]` + +10. **Step 6 — Implement Model Loading** + - `load_default_package` — Package with cache options + - `load_model` — `@check_optional_dependencies()`, + `package.resolve()`, `.eval()`, `.requires_grad_(False)` + - `.to()` — `super().to(device)`, return + `AssimilationModel` + - `[CONFIRM — Model Loading]` + +11. **Step 7 — Register in `__init__.py`** + - Add to `earth2studio/models/da/__init__.py` (alpha order) + - Add to `__all__` if present + +12. **Step 8 — Verify Style/Format/Lint** + - `make format`, `make lint`, `make license` + - Common mypy issues for DA: DataFrame type annotations, + optional cudf types + +13. **Reminders section** — Complete DA-specific DO/DON'T rules: + - DO return tuples from `input_coords`/`output_coords` + - DO use `FrameSchema` for tabular, `CoordSystem` for grid + - DO validate `request_time` from `obs.attrs` + - DO use `validate_observation_fields()`, + `filter_time_range()`, `dfseries_to_torch()` + - DO prime generator with `yield None` and handle + `GeneratorExit` + - DO return cupy arrays on GPU, numpy on CPU + - DO register `device_buffer` and expose `device` property + - DO use `loguru.logger`, never `print()` + - DO ensure all public functions have full type hints + - DO run `make format && make lint` before finalizing + - DO use `uv run python` for all Python commands + - DO NOT use `@batch_func()` or `@batch_coords()` + - DO NOT use `PrognosticMixin` + - DO NOT use `create_iterator` — DA uses + `create_generator` with send protocol + - DO NOT assume tensor inputs — inputs are DFs/DataArrays + - DO NOT forget cudf/cupy optional import pattern + - DO NOT make a general base class with intent to reuse + - DO NOT over-populate `load_model()` API + +- [ ] **Step 1: Read all reference files** + +Read: + +- `.cursor/rules/e2s-013-assimilation-models.mdc` (577 lines) +- `earth2studio/models/da/base.py` (188 lines) +- `earth2studio/models/da/interp.py` (593 lines — primary ref) +- `earth2studio/models/da/utils.py` (~175 lines) +- `.claude/skills/create-diagnostic-wrapper/SKILL.md` + (structural template) +- `docs/superpowers/specs/2026-04-13-assimilation-model-skills-design.md` + (297 lines) + +- [ ] **Step 2: Create directory** + +```bash +mkdir -p .claude/skills/create-assimilation-wrapper +``` + +- [ ] **Step 3: Write SKILL.md** + +Write `.claude/skills/create-assimilation-wrapper/SKILL.md` +with all content described above. The file should be +approximately 1200-1500 lines (similar to the dx create skill +at 1440 lines, but slightly shorter since there is no +NN/physics branching). + +**Self-review checklist before reporting:** + +- [ ] Frontmatter has correct name, description, argument-hint +- [ ] Critical differences table is present near the top +- [ ] All 9 steps (0-8) are present with correct headers +- [ ] 6 CONFIRM gates: Model Analysis, Dependencies, + Skeleton, Coordinates, Forward Pass, Model Loading +- [ ] Complete skeleton code block with all 11 method slots +- [ ] FrameSchema and CoordSystem patterns shown correctly +- [ ] `create_generator` with send protocol (both patterns) +- [ ] `validate_observation_fields`, `filter_time_range`, + `dfseries_to_torch` all shown +- [ ] cupy/cudf optional import and output patterns +- [ ] `device` property pattern (not + `next(self.parameters()).device`) +- [ ] No stale px/dx references (no `@batch_func`, + no `@batch_coords`, no `PrognosticMixin`, + no `create_iterator`, no `handshake_dim` for DF inputs) +- [ ] License header template (SPDX Apache-2.0, 2024-2026) +- [ ] Reminders section with all DO/DON'T rules from spec +- [ ] No placeholder TODOs in prose (skeleton code blocks may + have `# TODO` pseudocode) + +- [ ] **Step 4: Self-review and report** + +Report status: DONE / DONE_WITH_CONCERNS / BLOCKED / +NEEDS_CONTEXT + +--- + +## Task 2: Create `validate-assimilation-wrapper` SKILL.md + +**Files:** + +- Create: `.claude/skills/validate-assimilation-wrapper/SKILL.md` + +This task creates the full validation skill file. Model on the +existing `validate-diagnostic-wrapper/SKILL.md` structure but +with DA-specific content. + +**Key reference files to read before writing:** + +- `.claude/skills/validate-diagnostic-wrapper/SKILL.md` + (831 lines — structural template) +- `.cursor/rules/e2s-013-assimilation-models.mdc` (577 lines) +- `test/models/da/test_da_interp.py` (359 lines — test patterns) +- `earth2studio/models/da/interp.py` (593 lines — reference) +- `docs/superpowers/specs/2026-04-13-assimilation-model-skills-design.md` + (297 lines) + +**Structure to follow:** + +1. **YAML Frontmatter**: + + ```yaml + --- + name: validate-assimilation-wrapper + description: >- + Validate a newly created Earth2Studio data assimilation + model wrapper by writing unit tests (90% coverage + required), performing reference comparison with + DataFrame/DataArray outputs, generating sanity-check + plots, and opening a PR with automated code review. + Use after completing create-assimilation-wrapper + Steps 0-8. + argument-hint: >- + Name of the DA model class and test file (optional — + will be inferred from recent changes if not provided) + --- + ``` + +2. **Title and preamble** — uv/venv note, CONFIRM gate + explanation. + +3. **Step 1 — Write Pytest Unit Tests** + + DA-specific test patterns (NOT the same as dx/px tests): + + - `PhooModelName` dummy class — a simple + `torch.nn.Module` that returns known output shapes as + `xr.DataArray`. It must accept `pd.DataFrame` input + and return `xr.DataArray` output (NOT tensor-to-tensor + like dx/px dummies). + - `test_package` fixture — create dummy checkpoint, + save to tmp_path + - **Parametrize over pandas AND cudf** — + `@pytest.fixture` for pandas DataFrame, separate + fixture for cudf with `pytest.skip` guard + - **Parametrize over CPU AND GPU** devices + - `test_model_call` — create DataFrame with obs.attrs, + call model, verify: + - Output is `xr.DataArray` (not torch.Tensor) + - Output dims match output_coords + - Output data type matches device (cupy/numpy) + - `test_generator_protocol` — prime, send, close: + + ```python + gen = model.create_generator() + result = gen.send(None) # Prime + assert result is None # or initial state + da = gen.send(obs_df) # Send observations + assert isinstance(da, xr.DataArray) + gen.close() + ``` + + - `test_init_coords` — verify returns None (stateless) + or tuple (stateful) + - `test_input_coords` — verify returns tuple of + FrameSchema + - `test_time_tolerance` — verify filter_time_range works + - `test_empty_dataframe` — verify graceful handling + - `test_invalid_attrs` — verify raises on missing + `request_time` + - `test_validate_observation_fields` — verify raises on + bad columns + - `test_model_exceptions` — invalid inputs raise + ValueError/KeyError + - `@pytest.mark.package` integration test + + Show complete test file template with all test methods. + + - `[CONFIRM — Package Test]` (for integration test only) + +4. **Step 2 — Run Tests & Achieve 90% Coverage** + - 2a: Run test file with `-v --timeout=60` + - 2b: Coverage with `--slow` + `--cov=earth2studio/models/da/<filename>` + `--cov-report=term-missing` `--cov-fail-under=90` + - DA-specific coverage gaps: GeneratorExit cleanup path, + cudf code paths, empty DataFrame handling, time + tolerance edge cases, obs.attrs validation branches, + cupy vs numpy output paths + - 2c: Optional full suite `make pytest TOX_ENV=test-da` + +5. **Step 3 — Reference Comparison & Sanity-Check** + - 3a: Reference comparison script — compare `__call__` + output AND multi-step generator output against + reference implementation + - For DataArray: max_abs_diff, max_rel_diff, + correlation, torch.allclose + - For DataFrame: row counts, value ranges, spatial + coverage + - 3b: Model summary table — input schema, output grid, + variables, stateful/stateless, observation types, + cudf support + - 3c: Three sanity-check plot templates: + 1. **Spatial assimilated output** — `contourf` of + gridded DataArray output + 2. **Observation overlay** — scatter of input DataFrame + observations on assimilated grid (unique to DA) + 3. **Generator sequence** — multi-step evolution + - 3d: Side-by-side comparison scripts (ref vs E2S) + - 3e-3f: Run scripts, user confirms plots + - `[CONFIRM — Sanity-Check & Comparison]` + +6. **Step 4 — Branch, Commit & Open PR** + - `[CONFIRM — Ready to Submit]` with DA checklist + - 4a: Create branch `feat/da-model-<name>`, commit + - 4b: Identify fork remote, push + - 4c: Open PR with DA-specific template: + - Model type: Stateless / Stateful + - Input format: DataFrame / DataArray / Mixed + - Output format: DataArray / DataFrame + - Observation schema: columns and types + - Grid specification + - Time tolerance + - cudf/cupy support + - Reference comparison metrics + - 4d: Post sanity-check as PR comment + +7. **Step 5 — Automated Code Review (Greptile)** + - Same polling/triage/fix pattern as dx/px validate skills + - 5a: Poll for `greptile-apps[bot]` review + - 5b: Fetch + parse comments + - 5c: Categorize and present triage table + - `[CONFIRM — Review Triage]` + - 5d: Implement fixes + - 5e: Respond to comments + - 5f: Push and resolve + +8. **Reminders section** — Same DA-specific DO/DON'T as + create skill, plus: + - DO NOT commit sanity-check scripts/images + - DO NOT commit secrets/credentials + - DO maintain alphabetical order in `__init__.py` exports + - NEVER call `loop.set_default_executor()` + +- [ ] **Step 1: Read all reference files** + +Read: + +- `.claude/skills/validate-diagnostic-wrapper/SKILL.md` + (831 lines) +- `.cursor/rules/e2s-013-assimilation-models.mdc` (577 lines) +- `test/models/da/test_da_interp.py` (359 lines) +- `docs/superpowers/specs/2026-04-13-assimilation-model-skills-design.md` + (297 lines) + +- [ ] **Step 2: Create directory** + +```bash +mkdir -p .claude/skills/validate-assimilation-wrapper +``` + +- [ ] **Step 3: Write SKILL.md** + +Write `.claude/skills/validate-assimilation-wrapper/SKILL.md` +with all content described above. Target approximately +900-1100 lines (similar to dx validate at 831 lines, but +slightly longer due to DA-specific test templates). + +**Self-review checklist before reporting:** + +- [ ] Frontmatter has correct name, description, argument-hint +- [ ] All 5 steps present with correct headers +- [ ] 4 CONFIRM gates: Package Test, Sanity-Check & + Comparison, Ready to Submit, Review Triage +- [ ] Complete test file template with DA-specific patterns + (DataFrame fixtures, cudf parametrize, generator protocol + test, obs.attrs test) +- [ ] 90% coverage requirement with DA-specific gap list +- [ ] 3 sanity-check plot templates (spatial, observation + overlay, generator sequence) +- [ ] Reference comparison script with DataArray AND + DataFrame metrics +- [ ] DA-specific PR template with all required fields +- [ ] Greptile polling/triage workflow +- [ ] Reminders section with all DA-specific rules +- [ ] No stale dx/px references (no `@batch_func`, no tensor + assertions, no `handshake_dim` for DataFrame assertions) +- [ ] No data-source leftovers + +- [ ] **Step 4: Self-review and report** + +Report status: DONE / DONE_WITH_CONCERNS / BLOCKED / +NEEDS_CONTEXT + +--- + +## Task 3: Final Verification and Commit + +**Files:** + +- Verify: `.claude/skills/create-assimilation-wrapper/SKILL.md` +- Verify: `.claude/skills/validate-assimilation-wrapper/SKILL.md` + +**Depends on:** Tasks 1 and 2 + +- [ ] **Step 1: Verify both files exist and are non-empty** + +```bash +ls -la .claude/skills/create-assimilation-wrapper/SKILL.md +ls -la .claude/skills/validate-assimilation-wrapper/SKILL.md +wc -l .claude/skills/create-assimilation-wrapper/SKILL.md +wc -l .claude/skills/validate-assimilation-wrapper/SKILL.md +``` + +Expected: create skill ~1200-1500 lines, validate ~900-1100. + +- [ ] **Step 2: Verify frontmatter is parseable** + +```bash +head -6 .claude/skills/create-assimilation-wrapper/SKILL.md +head -6 .claude/skills/validate-assimilation-wrapper/SKILL.md +``` + +Both must start with `---` and have `name:`, `description:`, +`argument-hint:`. + +- [ ] **Step 3: Check for stale placeholders** + +```bash +grep -n "TODO" \ + .claude/skills/create-assimilation-wrapper/SKILL.md \ + | grep -v "# TODO" | head -20 +grep -n "TODO" \ + .claude/skills/validate-assimilation-wrapper/SKILL.md \ + | grep -v "# TODO" | head -20 +grep -n "TBD" \ + .claude/skills/create-assimilation-wrapper/SKILL.md +grep -n "TBD" \ + .claude/skills/validate-assimilation-wrapper/SKILL.md +``` + +No stale TODOs outside of skeleton code blocks. No TBDs. + +- [ ] **Step 4: Verify no stale px/dx references** + +```bash +grep -n \ + "@batch_func\|@batch_coords\|PrognosticMixin" \ + .claude/skills/create-assimilation-wrapper/SKILL.md \ + | head -20 +grep -n \ + "@batch_func\|@batch_coords\|PrognosticMixin" \ + .claude/skills/validate-assimilation-wrapper/SKILL.md \ + | head -20 +``` + +Any matches must be in "DO NOT" / negative context only. + +- [ ] **Step 5: Cross-reference spec coverage** + +Read spec file and verify every requirement has +corresponding content in the skill files: + +- 6 CONFIRM gates in create skill +- 4 CONFIRM gates in validate skill +- Critical differences table in create skill +- All Reminders items present + +- [ ] **Step 6: Commit** + +```bash +git add \ + .claude/skills/create-assimilation-wrapper/SKILL.md \ + .claude/skills/validate-assimilation-wrapper/SKILL.md +git commit -m "feat: add create and validate assimilation wrapper skills + +Add create-assimilation-wrapper (Steps 0-8) and +validate-assimilation-wrapper (Steps 1-5) skills for +guiding DA model wrapper implementation and validation. +DA-specific patterns: FrameSchema, send-protocol generators, +cudf/cupy support, validate_observation_fields." +``` + +- [ ] **Step 7: Verify commit** + +```bash +git log --oneline -1 +git status +``` diff --git a/docs/superpowers/specs/2026-04-13-assimilation-model-skills-design.md b/docs/superpowers/specs/2026-04-13-assimilation-model-skills-design.md new file mode 100644 index 000000000..c25468ca5 --- /dev/null +++ b/docs/superpowers/specs/2026-04-13-assimilation-model-skills-design.md @@ -0,0 +1,311 @@ +# Assimilation Model Skills Design + +## Goal + +Create two Claude skills for Earth2Studio data assimilation (DA) models: + +1. `create-assimilation-wrapper` — guide creation of a new DA model wrapper +2. `validate-assimilation-wrapper` — guide validation, PR, and code review + +These follow the established create+validate pattern from the diagnostic (dx) and +prognostic (px) skill pairs, but with **fundamental I/O differences** that make DA +models distinct. + +## Critical Differences from px/dx Models + +DA models are **not** tensor-in / tensor-out like px and dx models. The entire I/O +contract is different: + +| Aspect | px / dx | da | +| ------------------- | ------------------------------ | ---------------------------------------- | +| **Primary input** | `Tensor` + `CoordSystem` | `pd.DataFrame` or `xr.DataArray` | +| **Primary output** | `Tensor` + `CoordSystem` | `xr.DataArray` or `pd.DataFrame` | +| **Coord schemas** | `CoordSystem` only | `FrameSchema` + `CoordSystem` | +| **Coord returns** | Single dict | **Tuple** (even for single) | +| **Batch handling** | `@batch_func` / `@batch_coords` | **Neither** — N/A for DataFrame I/O | +| **PrognosticMixin** | Yes (px) / No (dx) | **No** | +| **Time integration** | `create_iterator` yields | `create_generator` with **send** | +| **Generator prime** | Yields initial condition | Yields `None` / state, `.send()` | +| **Generator cleanup** | N/A | Must handle `GeneratorExit` | +| **Init data** | No | Optional via `init_coords()` | +| **Time metadata** | Coordinate dimension | `obs.attrs["request_time"]` | +| **GPU data** | `Tensor` on device | **cupy** arrays, **cudf** DFs | +| **Input validation** | `handshake_dim`/`handshake_coords` | `validate_observation_fields()` | +| **Time filtering** | N/A | `filter_time_range()` | +| **Tensor conversion** | Already tensors | `dfseries_to_torch()` | +| **Device tracking** | `device_buffer` / `parameters()` | `device_buffer` + `@property` | + +These differences pervade every step of the skill — skeleton, coordinates, forward +pass, generator, testing, and validation. + +## Approach + +Single template for the general **stateful** case, with clear notes on simplifications +for **stateless** models (those returning `None` from `init_coords()`). No NN/physics +branching — DA models are all NN-backed or algorithmic with external packages. + +Reference input: user provides reference script/repo URL or local path (same as px +skill). + +--- + +## Skill 1: `create-assimilation-wrapper` + +### Frontmatter (create) + +```yaml +--- +name: create-assimilation-wrapper +description: >- + Create a new Earth2Studio data assimilation model (da) wrapper from a + reference inference script or repository. DA models ingest sparse + observations (DataFrames) and/or gridded state arrays (DataArrays) and + produce analysis output — they do NOT use @batch_func, @batch_coords, + or PrognosticMixin. +argument-hint: URL or local path to reference inference script/repo +--- +``` + +### Steps 0-8 + +The create skill handles **implementation only** — skeleton, coordinates, forward +pass, model loading, registration, and code quality. All testing, validation, +comparison scripts, and PR submission live in the validate skill. + +#### Step 0 — Obtain Reference Script + +- Accept `$ARGUMENTS` or ask user for reference +- Analyze: detect input types (DataFrame? DataArray?), output types, whether + stateful (needs init data) or stateless, dependency packages, model architecture +- Determine if `@torch.inference_mode()` is safe (no gradient flow needed) +- Present summary to user + +No NN/physics branching — all DA models get the same flow. + +**[CONFIRM — Model Analysis]** + +#### Step 1 — Examine Reference and Propose Dependencies + +- Identify Python packages (physicsnemo, scipy, healpy, cudf, cupy, etc.) +- Propose pyproject.toml dependency group named `da-<model-name>` +- Highlight cudf/cupy as optional GPU acceleration packages + +**[CONFIRM — Dependencies]** + +#### Step 2 — Add Dependencies to pyproject.toml + +#### Step 3 — Create Skeleton Class File + +Dual inheritance: `torch.nn.Module + AutoModelMixin` (NO PrognosticMixin). +Class-level `@check_optional_dependencies()` decorator. + +**Canonical DA method ordering:** + +1. `__init__` — register `device_buffer`, store model params, normalize tolerance +2. `device` property — `return self.device_buffer.device` +3. `init_coords` — `None` for stateless, tuple for stateful +4. `input_coords` — tuple of `FrameSchema` (DF) or `CoordSystem` (DA) +5. `output_coords` — accept `input_coords` tuple + `request_time`, return tuple +6. `load_default_package` — classmethod +7. `load_model` — classmethod with `@check_optional_dependencies()` +8. `to` — device management, return `AssimilationModel` +9. Private/support methods (e.g., `_interpolate`, `_forward`, spatial lookups) +10. `__call__` — stateless forward, accept `*args: pd.DataFrame | xr.DataArray | None` +11. `create_generator` — bidirectional generator with send protocol + +**DA-specific skeleton elements:** + +- `FrameSchema` for observation inputs (time, lat, lon, observation, variable) +- `CoordSystem` for gridded outputs +- `request_time` from `obs.attrs`, NOT a coordinate dimension +- `validate_observation_fields()` call in `__call__` +- `filter_time_range()` for time-window filtering +- `dfseries_to_torch()` for zero-copy DataFrame→tensor +- cupy/cudf support: `cp.asarray()` for GPU output, `.cpu().numpy()` for CPU +- Generator: `yield None` to prime, `observations = yield result`, handle `GeneratorExit` +- `@torch.inference_mode()` unless gradient flow is required (document reason if omitted) +- No `@batch_func()`, no `@batch_coords()`, no `PrognosticMixin` + +**[CONFIRM — Skeleton]** + +#### Step 4 — Implement Coordinate System + +Key differences from px/dx: + +- `init_coords()` returns `None` for stateless models or tuple for stateful +- `input_coords()` returns **tuple** of `FrameSchema` — one per `__call__` arg +- `FrameSchema` keys are DF column names (time, lat, lon, observation, variable) +- `output_coords()` accepts tuple + `request_time`, returns tuple of `CoordSystem` +- Use `handshake_dim`/`handshake_coords`/`handshake_size` only for `CoordSystem` +- Use `validate_observation_fields()` for `FrameSchema` inputs + +**[CONFIRM — Coordinates]** + +#### Step 5 — Implement Forward Pass + +Two methods: + +- `__call__`: Extract `request_time` from `obs.attrs`, validate with + `validate_observation_fields()`, filter with `filter_time_range()`, convert with + `dfseries_to_torch()`, run model, build `xr.DataArray` output with cupy/numpy +- `create_generator`: Prime with `yield None` (stateless) or `yield initial_state` + (stateful), loop with `observations = yield result`, handle `GeneratorExit` + +**[CONFIRM — Forward Pass]** + +**Step 6 — Implement Model Loading** +`load_default_package`, `load_model` with `@check_optional_dependencies()` + +**[CONFIRM — Model Loading]** + +**Step 7 — Register in `__init__.py`** +Add to `earth2studio/models/da/__init__.py` + +**Step 8 — Verify Style/Format/Lint** +`make format`, `make lint`, `make license` + +The create skill ends here. All testing, validation, comparison, and PR work is +handled by the `validate-assimilation-wrapper` skill. + +**Reminders section** — DA-specific DO/DON'T: + +- DO return tuples from `input_coords` and `output_coords` +- DO use `FrameSchema` for tabular inputs, `CoordSystem` for gridded outputs +- DO validate `request_time` from `obs.attrs` +- DO use `validate_observation_fields()`, `filter_time_range()`, `dfseries_to_torch()` +- DO prime generator with `yield None` and handle `GeneratorExit` +- DO return cupy arrays on GPU, numpy on CPU +- DO register `device_buffer` and expose `device` property +- DO NOT use `@batch_func()` or `@batch_coords()` +- DO NOT use `PrognosticMixin` +- DO NOT use `create_iterator` — DA uses `create_generator` with send protocol +- DO NOT assume tensor inputs — inputs are DataFrames/DataArrays +- DO NOT forget cudf/cupy optional import pattern + +--- + +## Skill 2: `validate-assimilation-wrapper` + +### Frontmatter (validate) + +```yaml +--- +name: validate-assimilation-wrapper +description: >- + Validate a newly created Earth2Studio data assimilation model wrapper by + writing unit tests (90% coverage required), performing reference comparison + with DataFrame/DataArray outputs, generating sanity-check plots, and + opening a PR with automated code review. Use after completing + create-assimilation-wrapper Steps 0-8. +argument-hint: Name of the DA model class and test file +--- +``` + +### Steps 1-6 + +#### Step 1 — Write Pytest Unit Tests + +DA-specific test patterns: + +- PhooModelName dummy that returns known output shapes +- **Parametrize over pandas AND cudf** (skip cudf if unavailable) +- **Parametrize over CPU AND GPU** devices +- Test `__call__` with DataFrame input, verify `xr.DataArray` output +- Test generator protocol: prime → send → close +- Test `init_coords` returns correct type (None or tuple) +- Test time tolerance filtering +- Test empty DataFrame handling +- Test invalid `obs.attrs` (missing `request_time`) +- Test `validate_observation_fields` raises on bad columns +- `@pytest.mark.package` integration test + +**[CONFIRM — Package Test]** (for `@pytest.mark.package` integration test only) + +#### Step 2 — Run Tests and Achieve 90% Coverage + +- Run test file with `-v --timeout=60`, all must pass +- Coverage with `--slow`, `--cov`, **`--cov-fail-under=90`** — the new DA model + file must achieve **at least 90% line coverage** +- DA-specific coverage gaps to watch: generator cleanup path (`GeneratorExit`), + cudf code paths, empty DataFrame handling, time tolerance edge cases, + `obs.attrs` validation branches, cupy vs numpy output paths +- If coverage is below 90%, add tests to cover missing lines and re-run + +#### Step 3 — Reference Comparison and Sanity-Check + +3a. Reference comparison — compare `__call__` output AND multi-step generator +output against reference implementation. For DataArray output: compute tolerance +metrics (max abs diff, correlation, allclose). For DataFrame output: compare row +counts, value ranges, spatial coverage. + +3b. Model summary table — input schema, output grid, variables, stateful/stateless, +observation types, cudf support. + +3c. Three sanity-check plot templates: + +1. **Spatial assimilated output** — contourf of gridded DataArray output (like dx + spatial plot) +2. **Observation overlay** — scatter of input DataFrame observations overlaid on + assimilated grid output (unique to DA — shows sparse→dense mapping) +3. **Generator sequence** — multi-step assimilation evolution over time (for + stateful models) or repeated independent calls (for stateless) + +3d. Side-by-side comparison scripts — reference inference vs Earth2Studio equivalent + +3e-3f. Run scripts, user confirms plots. + +**[CONFIRM — Sanity-Check & Comparison]** + +#### Step 4 — Branch, Commit and Open PR + +DA-specific PR template fields: + +- Model type: Stateless / Stateful +- Input format: DataFrame / DataArray / Mixed +- Output format: DataArray / DataFrame +- Observation schema: columns and types +- Grid specification: lat-lon / HRRR / HealPix / etc. +- Time tolerance: default value +- cudf/cupy support: Yes / No +- Reference comparison metrics table + +**[CONFIRM — Ready to Submit]** + +**Step 5 — Automated Code Review (Greptile)** +Same polling/triage/fix pattern as dx/px. + +**[CONFIRM — Review Triage]** + +**Reminders** — same DA-specific rules as create skill. + +--- + +## Reference Files + +| File | Purpose | +| --------------------------------------------------- | ------------------------------------------ | +| `.cursor/rules/e2s-013-assimilation-models.mdc` | Authoritative DA rules (577 lines) | +| `earth2studio/models/da/base.py` | AssimilationModel Protocol definition | +| `earth2studio/models/da/interp.py` | Simple stateless DA example (593 lines) | +| `earth2studio/models/da/sda_stormcast.py` | Complex stateful DA example (919 lines) | +| `earth2studio/models/da/utils.py` | DA utilities (validate, filter, convert) | +| `test/models/da/test_da_interp.py` | DA test patterns (359 lines) | +| `test/models/da/test_da_healda.py` | Complex DA test patterns with mocks | + +## CONFIRM Gates Summary + +### Create skill (6 gates) + +1. `[CONFIRM — Model Analysis]` +2. `[CONFIRM — Dependencies]` +3. `[CONFIRM — Skeleton]` +4. `[CONFIRM — Coordinates]` +5. `[CONFIRM — Forward Pass]` +6. `[CONFIRM — Model Loading]` + +### Validate skill (4 gates) + +1. `[CONFIRM — Package Test]` +2. `[CONFIRM — Sanity-Check & Comparison]` +3. `[CONFIRM — Ready to Submit]` +4. `[CONFIRM — Review Triage]` From 794e8d79d89bb74ed59c1a131c54740a04581fa2 Mon Sep 17 00:00:00 2001 From: Nicholas Geneva <ngeneva@nvidia.com> Date: Mon, 13 Apr 2026 22:58:26 +0000 Subject: [PATCH 4/7] Ignore super powers --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index b6c82f78d..77346fce1 100644 --- a/.gitignore +++ b/.gitignore @@ -99,6 +99,7 @@ _build docs/modules/generated docs/modules/backreferences docs/examples +docs/superpowers # Pytestmon .testmondata* From 8782a70a4861dcc43dba462c40bd19badf193008 Mon Sep 17 00:00:00 2001 From: Nicholas Geneva <ngeneva@nvidia.com> Date: Mon, 13 Apr 2026 23:00:19 +0000 Subject: [PATCH 5/7] Ignore super powers --- .../2026-04-13-assimilation-model-skills.md | 582 ------------------ ...-04-13-assimilation-model-skills-design.md | 311 ---------- 2 files changed, 893 deletions(-) delete mode 100644 docs/superpowers/plans/2026-04-13-assimilation-model-skills.md delete mode 100644 docs/superpowers/specs/2026-04-13-assimilation-model-skills-design.md diff --git a/docs/superpowers/plans/2026-04-13-assimilation-model-skills.md b/docs/superpowers/plans/2026-04-13-assimilation-model-skills.md deleted file mode 100644 index d7161d94e..000000000 --- a/docs/superpowers/plans/2026-04-13-assimilation-model-skills.md +++ /dev/null @@ -1,582 +0,0 @@ -# Assimilation Model Skills Implementation Plan - -> **For agentic workers:** REQUIRED SUB-SKILL: Use -> superpowers:subagent-driven-development (recommended) or -> superpowers:executing-plans to implement this plan task-by-task. -> Steps use checkbox (`- [ ]`) syntax for tracking. - -**Goal:** Create two Claude skills (`create-assimilation-wrapper` -and `validate-assimilation-wrapper`) that guide agents through -building and validating Earth2Studio data assimilation model -wrappers. - -**Architecture:** Each skill is a single Markdown file in -`.claude/skills/<name>/SKILL.md` with YAML frontmatter, numbered -steps with `[CONFIRM]` gates, complete code blocks, and a -Reminders section. The create skill covers Steps 0-8 -(implementation only), the validate skill covers Steps 1-5 -(testing, comparison, PR, Greptile). - -**Tech Stack:** Markdown skill files, Python code blocks -referencing Earth2Studio DA patterns (pandas/cudf DataFrames, -xarray DataArrays, FrameSchema, CoordSystem, send-protocol -generators). - ---- - -## Task 1: Create `create-assimilation-wrapper` SKILL.md - -**Files:** - -- Create: `.claude/skills/create-assimilation-wrapper/SKILL.md` - -This task creates the full skill file for guiding creation of DA -model wrappers. The file should be modeled on the existing -`create-diagnostic-wrapper/SKILL.md` structure but with -DA-specific content throughout. - -**Key reference files to read before writing:** - -- `.cursor/rules/e2s-013-assimilation-models.mdc` (577 lines) - — authoritative DA implementation rules -- `earth2studio/models/da/base.py` (188 lines) - — AssimilationModel Protocol -- `earth2studio/models/da/interp.py` (593 lines) - — simple stateless DA example -- `earth2studio/models/da/utils.py` - — DA utilities (validate, filter, convert) -- `.claude/skills/create-diagnostic-wrapper/SKILL.md` - — structural template (layout, step structure, CONFIRM gates) -- `docs/superpowers/specs/2026-04-13-assimilation-model-skills-design.md` - — approved spec - -**Structure to follow:** - -The file must contain: - -1. **YAML Frontmatter** (lines 1-5): - - ```yaml - --- - name: create-assimilation-wrapper - description: >- - Create a new Earth2Studio data assimilation model (da) - wrapper from a reference inference script or repository. - DA models ingest sparse observations (DataFrames) and/or - gridded state arrays (DataArrays) and produce analysis - output — they do NOT use @batch_func, @batch_coords, - or PrognosticMixin. - argument-hint: >- - URL or local path to reference inference script/repo - (optional — will be asked if not provided) - --- - ``` - -2. **Title and preamble** — "Create Assimilation Model Wrapper", - CONFIRM gate explanation, uv/venv environment note. - -3. **Critical Differences from px/dx table** — The full table - from the spec (18 rows). This MUST be placed prominently - near the top, right after the preamble. It's the most - important reference for any agent using this skill. - -4. **Step 0 — Obtain Reference & Analyze Model** - - 0a: Accept `$ARGUMENTS` or ask user (same as dx skill) - - 0b: Analyze the reference to determine: - - Input types: DataFrame only, DataArray only, or mixed - - Output types: DataArray (gridded), DataFrame (tabular) - - Stateful vs stateless (state between steps?) - - Whether `@torch.inference_mode()` is safe - - Dependency packages - - No NN/physics branching — all DA models get the same flow - - Present analysis summary to user - - `[CONFIRM — Model Analysis]` - -5. **Step 1 — Examine Reference & Propose Dependencies** - - 1a: Identify packages (physicsnemo, scipy, healpy, etc.) - - 1b: Propose pyproject.toml group `da-<model-name>` - - Highlight cudf/cupy as optional GPU acceleration - - `[CONFIRM — Dependencies]` - -6. **Step 2 — Add Dependencies to pyproject.toml** - -7. **Step 3 — Create Skeleton Class File** - - File path: `earth2studio/models/da/<filename>.py` - - License header (SPDX Apache-2.0, 2024-2026) - - Dual inheritance: `torch.nn.Module + AutoModelMixin` - (NO PrognosticMixin) - - Class-level `@check_optional_dependencies()` decorator - - Optional dependency try/except with - `OptionalDependencyFailure("da-<name>")` - - cupy/cudf optional imports - - **Canonical DA method ordering** (MUST be enforced): - 1. `__init__` — register `device_buffer`, store model, - normalize time tolerance - 2. `device` property — `return self.device_buffer.device` - 3. `init_coords` — `None` for stateless, - `tuple[CoordSystem]` for stateful - 4. `input_coords` — `tuple[FrameSchema]` for DataFrame, - `tuple[CoordSystem]` for DataArray - 5. `output_coords` — accept `input_coords` tuple + - `request_time` kwarg, return tuple - 6. `load_default_package` — classmethod - 7. `load_model` — classmethod with - `@check_optional_dependencies()` - 8. `to` — device management, return `AssimilationModel` - 9. Private/support methods - 10. `__call__` — stateless forward - 11. `create_generator` — bidirectional generator with send - - **Complete skeleton code block** — must include all methods - with pseudocode `# TODO` comments. Use the complete example - from `e2s-013-assimilation-models.mdc` lines 420-544 as - the primary template. - - **Key differences from dx/px skeletons** to document: - - - No `@batch_func()`, no `@batch_coords()`, - no `PrognosticMixin` - - `input_coords` and `output_coords` return **tuples** - (even for single) - - `FrameSchema` for tabular inputs, - `CoordSystem` for gridded outputs - - `device` property instead of - `next(self.parameters()).device` - - `create_generator` with send protocol instead of - `create_iterator` - - `validate_observation_fields()` instead of - `handshake_coords`/`handshake_dim` for DataFrames - - `request_time` from `obs.attrs`, not a coordinate dim - - - `[CONFIRM — Skeleton]` - -8. **Step 4 — Implement Coordinate System** - - 4a: `init_coords()` — return `None` for stateless, - or tuple of `CoordSystem`/`FrameSchema` for stateful. - Show both patterns. - - 4b: `input_coords()` — return tuple of `FrameSchema`. - Show the standard 5-column schema: time, lat, lon, - observation, variable. Explain `np.empty(0, dtype=...)` - for unbounded vs `np.array([...])` for enumerated. - - 4c: `output_coords()` — accept tuple + `request_time`, - return tuple of `CoordSystem`. Show handshake helpers - only for `CoordSystem` inputs. Show - `validate_observation_fields()` for `FrameSchema`. - - `[CONFIRM — Coordinates]` - -9. **Step 5 — Implement Forward Pass** - - 5a: `__call__` — extract `request_time` from - `obs.attrs`, validate with - `validate_observation_fields()`, filter with - `filter_time_range()`, convert with - `dfseries_to_torch()`, run model, build - `xr.DataArray` with cupy/numpy based on device - - 5b: `create_generator` — two patterns: - - Stateless: `observations = yield None`, loop calling - `self.__call__(observations)`, handle `GeneratorExit` - - Stateful: accept init args, - `obs = yield initial_state`, loop with state updates, - handle `GeneratorExit` - - Code blocks for both patterns - - `[CONFIRM — Forward Pass]` - -10. **Step 6 — Implement Model Loading** - - `load_default_package` — Package with cache options - - `load_model` — `@check_optional_dependencies()`, - `package.resolve()`, `.eval()`, `.requires_grad_(False)` - - `.to()` — `super().to(device)`, return - `AssimilationModel` - - `[CONFIRM — Model Loading]` - -11. **Step 7 — Register in `__init__.py`** - - Add to `earth2studio/models/da/__init__.py` (alpha order) - - Add to `__all__` if present - -12. **Step 8 — Verify Style/Format/Lint** - - `make format`, `make lint`, `make license` - - Common mypy issues for DA: DataFrame type annotations, - optional cudf types - -13. **Reminders section** — Complete DA-specific DO/DON'T rules: - - DO return tuples from `input_coords`/`output_coords` - - DO use `FrameSchema` for tabular, `CoordSystem` for grid - - DO validate `request_time` from `obs.attrs` - - DO use `validate_observation_fields()`, - `filter_time_range()`, `dfseries_to_torch()` - - DO prime generator with `yield None` and handle - `GeneratorExit` - - DO return cupy arrays on GPU, numpy on CPU - - DO register `device_buffer` and expose `device` property - - DO use `loguru.logger`, never `print()` - - DO ensure all public functions have full type hints - - DO run `make format && make lint` before finalizing - - DO use `uv run python` for all Python commands - - DO NOT use `@batch_func()` or `@batch_coords()` - - DO NOT use `PrognosticMixin` - - DO NOT use `create_iterator` — DA uses - `create_generator` with send protocol - - DO NOT assume tensor inputs — inputs are DFs/DataArrays - - DO NOT forget cudf/cupy optional import pattern - - DO NOT make a general base class with intent to reuse - - DO NOT over-populate `load_model()` API - -- [ ] **Step 1: Read all reference files** - -Read: - -- `.cursor/rules/e2s-013-assimilation-models.mdc` (577 lines) -- `earth2studio/models/da/base.py` (188 lines) -- `earth2studio/models/da/interp.py` (593 lines — primary ref) -- `earth2studio/models/da/utils.py` (~175 lines) -- `.claude/skills/create-diagnostic-wrapper/SKILL.md` - (structural template) -- `docs/superpowers/specs/2026-04-13-assimilation-model-skills-design.md` - (297 lines) - -- [ ] **Step 2: Create directory** - -```bash -mkdir -p .claude/skills/create-assimilation-wrapper -``` - -- [ ] **Step 3: Write SKILL.md** - -Write `.claude/skills/create-assimilation-wrapper/SKILL.md` -with all content described above. The file should be -approximately 1200-1500 lines (similar to the dx create skill -at 1440 lines, but slightly shorter since there is no -NN/physics branching). - -**Self-review checklist before reporting:** - -- [ ] Frontmatter has correct name, description, argument-hint -- [ ] Critical differences table is present near the top -- [ ] All 9 steps (0-8) are present with correct headers -- [ ] 6 CONFIRM gates: Model Analysis, Dependencies, - Skeleton, Coordinates, Forward Pass, Model Loading -- [ ] Complete skeleton code block with all 11 method slots -- [ ] FrameSchema and CoordSystem patterns shown correctly -- [ ] `create_generator` with send protocol (both patterns) -- [ ] `validate_observation_fields`, `filter_time_range`, - `dfseries_to_torch` all shown -- [ ] cupy/cudf optional import and output patterns -- [ ] `device` property pattern (not - `next(self.parameters()).device`) -- [ ] No stale px/dx references (no `@batch_func`, - no `@batch_coords`, no `PrognosticMixin`, - no `create_iterator`, no `handshake_dim` for DF inputs) -- [ ] License header template (SPDX Apache-2.0, 2024-2026) -- [ ] Reminders section with all DO/DON'T rules from spec -- [ ] No placeholder TODOs in prose (skeleton code blocks may - have `# TODO` pseudocode) - -- [ ] **Step 4: Self-review and report** - -Report status: DONE / DONE_WITH_CONCERNS / BLOCKED / -NEEDS_CONTEXT - ---- - -## Task 2: Create `validate-assimilation-wrapper` SKILL.md - -**Files:** - -- Create: `.claude/skills/validate-assimilation-wrapper/SKILL.md` - -This task creates the full validation skill file. Model on the -existing `validate-diagnostic-wrapper/SKILL.md` structure but -with DA-specific content. - -**Key reference files to read before writing:** - -- `.claude/skills/validate-diagnostic-wrapper/SKILL.md` - (831 lines — structural template) -- `.cursor/rules/e2s-013-assimilation-models.mdc` (577 lines) -- `test/models/da/test_da_interp.py` (359 lines — test patterns) -- `earth2studio/models/da/interp.py` (593 lines — reference) -- `docs/superpowers/specs/2026-04-13-assimilation-model-skills-design.md` - (297 lines) - -**Structure to follow:** - -1. **YAML Frontmatter**: - - ```yaml - --- - name: validate-assimilation-wrapper - description: >- - Validate a newly created Earth2Studio data assimilation - model wrapper by writing unit tests (90% coverage - required), performing reference comparison with - DataFrame/DataArray outputs, generating sanity-check - plots, and opening a PR with automated code review. - Use after completing create-assimilation-wrapper - Steps 0-8. - argument-hint: >- - Name of the DA model class and test file (optional — - will be inferred from recent changes if not provided) - --- - ``` - -2. **Title and preamble** — uv/venv note, CONFIRM gate - explanation. - -3. **Step 1 — Write Pytest Unit Tests** - - DA-specific test patterns (NOT the same as dx/px tests): - - - `PhooModelName` dummy class — a simple - `torch.nn.Module` that returns known output shapes as - `xr.DataArray`. It must accept `pd.DataFrame` input - and return `xr.DataArray` output (NOT tensor-to-tensor - like dx/px dummies). - - `test_package` fixture — create dummy checkpoint, - save to tmp_path - - **Parametrize over pandas AND cudf** — - `@pytest.fixture` for pandas DataFrame, separate - fixture for cudf with `pytest.skip` guard - - **Parametrize over CPU AND GPU** devices - - `test_model_call` — create DataFrame with obs.attrs, - call model, verify: - - Output is `xr.DataArray` (not torch.Tensor) - - Output dims match output_coords - - Output data type matches device (cupy/numpy) - - `test_generator_protocol` — prime, send, close: - - ```python - gen = model.create_generator() - result = gen.send(None) # Prime - assert result is None # or initial state - da = gen.send(obs_df) # Send observations - assert isinstance(da, xr.DataArray) - gen.close() - ``` - - - `test_init_coords` — verify returns None (stateless) - or tuple (stateful) - - `test_input_coords` — verify returns tuple of - FrameSchema - - `test_time_tolerance` — verify filter_time_range works - - `test_empty_dataframe` — verify graceful handling - - `test_invalid_attrs` — verify raises on missing - `request_time` - - `test_validate_observation_fields` — verify raises on - bad columns - - `test_model_exceptions` — invalid inputs raise - ValueError/KeyError - - `@pytest.mark.package` integration test - - Show complete test file template with all test methods. - - - `[CONFIRM — Package Test]` (for integration test only) - -4. **Step 2 — Run Tests & Achieve 90% Coverage** - - 2a: Run test file with `-v --timeout=60` - - 2b: Coverage with `--slow` - `--cov=earth2studio/models/da/<filename>` - `--cov-report=term-missing` `--cov-fail-under=90` - - DA-specific coverage gaps: GeneratorExit cleanup path, - cudf code paths, empty DataFrame handling, time - tolerance edge cases, obs.attrs validation branches, - cupy vs numpy output paths - - 2c: Optional full suite `make pytest TOX_ENV=test-da` - -5. **Step 3 — Reference Comparison & Sanity-Check** - - 3a: Reference comparison script — compare `__call__` - output AND multi-step generator output against - reference implementation - - For DataArray: max_abs_diff, max_rel_diff, - correlation, torch.allclose - - For DataFrame: row counts, value ranges, spatial - coverage - - 3b: Model summary table — input schema, output grid, - variables, stateful/stateless, observation types, - cudf support - - 3c: Three sanity-check plot templates: - 1. **Spatial assimilated output** — `contourf` of - gridded DataArray output - 2. **Observation overlay** — scatter of input DataFrame - observations on assimilated grid (unique to DA) - 3. **Generator sequence** — multi-step evolution - - 3d: Side-by-side comparison scripts (ref vs E2S) - - 3e-3f: Run scripts, user confirms plots - - `[CONFIRM — Sanity-Check & Comparison]` - -6. **Step 4 — Branch, Commit & Open PR** - - `[CONFIRM — Ready to Submit]` with DA checklist - - 4a: Create branch `feat/da-model-<name>`, commit - - 4b: Identify fork remote, push - - 4c: Open PR with DA-specific template: - - Model type: Stateless / Stateful - - Input format: DataFrame / DataArray / Mixed - - Output format: DataArray / DataFrame - - Observation schema: columns and types - - Grid specification - - Time tolerance - - cudf/cupy support - - Reference comparison metrics - - 4d: Post sanity-check as PR comment - -7. **Step 5 — Automated Code Review (Greptile)** - - Same polling/triage/fix pattern as dx/px validate skills - - 5a: Poll for `greptile-apps[bot]` review - - 5b: Fetch + parse comments - - 5c: Categorize and present triage table - - `[CONFIRM — Review Triage]` - - 5d: Implement fixes - - 5e: Respond to comments - - 5f: Push and resolve - -8. **Reminders section** — Same DA-specific DO/DON'T as - create skill, plus: - - DO NOT commit sanity-check scripts/images - - DO NOT commit secrets/credentials - - DO maintain alphabetical order in `__init__.py` exports - - NEVER call `loop.set_default_executor()` - -- [ ] **Step 1: Read all reference files** - -Read: - -- `.claude/skills/validate-diagnostic-wrapper/SKILL.md` - (831 lines) -- `.cursor/rules/e2s-013-assimilation-models.mdc` (577 lines) -- `test/models/da/test_da_interp.py` (359 lines) -- `docs/superpowers/specs/2026-04-13-assimilation-model-skills-design.md` - (297 lines) - -- [ ] **Step 2: Create directory** - -```bash -mkdir -p .claude/skills/validate-assimilation-wrapper -``` - -- [ ] **Step 3: Write SKILL.md** - -Write `.claude/skills/validate-assimilation-wrapper/SKILL.md` -with all content described above. Target approximately -900-1100 lines (similar to dx validate at 831 lines, but -slightly longer due to DA-specific test templates). - -**Self-review checklist before reporting:** - -- [ ] Frontmatter has correct name, description, argument-hint -- [ ] All 5 steps present with correct headers -- [ ] 4 CONFIRM gates: Package Test, Sanity-Check & - Comparison, Ready to Submit, Review Triage -- [ ] Complete test file template with DA-specific patterns - (DataFrame fixtures, cudf parametrize, generator protocol - test, obs.attrs test) -- [ ] 90% coverage requirement with DA-specific gap list -- [ ] 3 sanity-check plot templates (spatial, observation - overlay, generator sequence) -- [ ] Reference comparison script with DataArray AND - DataFrame metrics -- [ ] DA-specific PR template with all required fields -- [ ] Greptile polling/triage workflow -- [ ] Reminders section with all DA-specific rules -- [ ] No stale dx/px references (no `@batch_func`, no tensor - assertions, no `handshake_dim` for DataFrame assertions) -- [ ] No data-source leftovers - -- [ ] **Step 4: Self-review and report** - -Report status: DONE / DONE_WITH_CONCERNS / BLOCKED / -NEEDS_CONTEXT - ---- - -## Task 3: Final Verification and Commit - -**Files:** - -- Verify: `.claude/skills/create-assimilation-wrapper/SKILL.md` -- Verify: `.claude/skills/validate-assimilation-wrapper/SKILL.md` - -**Depends on:** Tasks 1 and 2 - -- [ ] **Step 1: Verify both files exist and are non-empty** - -```bash -ls -la .claude/skills/create-assimilation-wrapper/SKILL.md -ls -la .claude/skills/validate-assimilation-wrapper/SKILL.md -wc -l .claude/skills/create-assimilation-wrapper/SKILL.md -wc -l .claude/skills/validate-assimilation-wrapper/SKILL.md -``` - -Expected: create skill ~1200-1500 lines, validate ~900-1100. - -- [ ] **Step 2: Verify frontmatter is parseable** - -```bash -head -6 .claude/skills/create-assimilation-wrapper/SKILL.md -head -6 .claude/skills/validate-assimilation-wrapper/SKILL.md -``` - -Both must start with `---` and have `name:`, `description:`, -`argument-hint:`. - -- [ ] **Step 3: Check for stale placeholders** - -```bash -grep -n "TODO" \ - .claude/skills/create-assimilation-wrapper/SKILL.md \ - | grep -v "# TODO" | head -20 -grep -n "TODO" \ - .claude/skills/validate-assimilation-wrapper/SKILL.md \ - | grep -v "# TODO" | head -20 -grep -n "TBD" \ - .claude/skills/create-assimilation-wrapper/SKILL.md -grep -n "TBD" \ - .claude/skills/validate-assimilation-wrapper/SKILL.md -``` - -No stale TODOs outside of skeleton code blocks. No TBDs. - -- [ ] **Step 4: Verify no stale px/dx references** - -```bash -grep -n \ - "@batch_func\|@batch_coords\|PrognosticMixin" \ - .claude/skills/create-assimilation-wrapper/SKILL.md \ - | head -20 -grep -n \ - "@batch_func\|@batch_coords\|PrognosticMixin" \ - .claude/skills/validate-assimilation-wrapper/SKILL.md \ - | head -20 -``` - -Any matches must be in "DO NOT" / negative context only. - -- [ ] **Step 5: Cross-reference spec coverage** - -Read spec file and verify every requirement has -corresponding content in the skill files: - -- 6 CONFIRM gates in create skill -- 4 CONFIRM gates in validate skill -- Critical differences table in create skill -- All Reminders items present - -- [ ] **Step 6: Commit** - -```bash -git add \ - .claude/skills/create-assimilation-wrapper/SKILL.md \ - .claude/skills/validate-assimilation-wrapper/SKILL.md -git commit -m "feat: add create and validate assimilation wrapper skills - -Add create-assimilation-wrapper (Steps 0-8) and -validate-assimilation-wrapper (Steps 1-5) skills for -guiding DA model wrapper implementation and validation. -DA-specific patterns: FrameSchema, send-protocol generators, -cudf/cupy support, validate_observation_fields." -``` - -- [ ] **Step 7: Verify commit** - -```bash -git log --oneline -1 -git status -``` diff --git a/docs/superpowers/specs/2026-04-13-assimilation-model-skills-design.md b/docs/superpowers/specs/2026-04-13-assimilation-model-skills-design.md deleted file mode 100644 index c25468ca5..000000000 --- a/docs/superpowers/specs/2026-04-13-assimilation-model-skills-design.md +++ /dev/null @@ -1,311 +0,0 @@ -# Assimilation Model Skills Design - -## Goal - -Create two Claude skills for Earth2Studio data assimilation (DA) models: - -1. `create-assimilation-wrapper` — guide creation of a new DA model wrapper -2. `validate-assimilation-wrapper` — guide validation, PR, and code review - -These follow the established create+validate pattern from the diagnostic (dx) and -prognostic (px) skill pairs, but with **fundamental I/O differences** that make DA -models distinct. - -## Critical Differences from px/dx Models - -DA models are **not** tensor-in / tensor-out like px and dx models. The entire I/O -contract is different: - -| Aspect | px / dx | da | -| ------------------- | ------------------------------ | ---------------------------------------- | -| **Primary input** | `Tensor` + `CoordSystem` | `pd.DataFrame` or `xr.DataArray` | -| **Primary output** | `Tensor` + `CoordSystem` | `xr.DataArray` or `pd.DataFrame` | -| **Coord schemas** | `CoordSystem` only | `FrameSchema` + `CoordSystem` | -| **Coord returns** | Single dict | **Tuple** (even for single) | -| **Batch handling** | `@batch_func` / `@batch_coords` | **Neither** — N/A for DataFrame I/O | -| **PrognosticMixin** | Yes (px) / No (dx) | **No** | -| **Time integration** | `create_iterator` yields | `create_generator` with **send** | -| **Generator prime** | Yields initial condition | Yields `None` / state, `.send()` | -| **Generator cleanup** | N/A | Must handle `GeneratorExit` | -| **Init data** | No | Optional via `init_coords()` | -| **Time metadata** | Coordinate dimension | `obs.attrs["request_time"]` | -| **GPU data** | `Tensor` on device | **cupy** arrays, **cudf** DFs | -| **Input validation** | `handshake_dim`/`handshake_coords` | `validate_observation_fields()` | -| **Time filtering** | N/A | `filter_time_range()` | -| **Tensor conversion** | Already tensors | `dfseries_to_torch()` | -| **Device tracking** | `device_buffer` / `parameters()` | `device_buffer` + `@property` | - -These differences pervade every step of the skill — skeleton, coordinates, forward -pass, generator, testing, and validation. - -## Approach - -Single template for the general **stateful** case, with clear notes on simplifications -for **stateless** models (those returning `None` from `init_coords()`). No NN/physics -branching — DA models are all NN-backed or algorithmic with external packages. - -Reference input: user provides reference script/repo URL or local path (same as px -skill). - ---- - -## Skill 1: `create-assimilation-wrapper` - -### Frontmatter (create) - -```yaml ---- -name: create-assimilation-wrapper -description: >- - Create a new Earth2Studio data assimilation model (da) wrapper from a - reference inference script or repository. DA models ingest sparse - observations (DataFrames) and/or gridded state arrays (DataArrays) and - produce analysis output — they do NOT use @batch_func, @batch_coords, - or PrognosticMixin. -argument-hint: URL or local path to reference inference script/repo ---- -``` - -### Steps 0-8 - -The create skill handles **implementation only** — skeleton, coordinates, forward -pass, model loading, registration, and code quality. All testing, validation, -comparison scripts, and PR submission live in the validate skill. - -#### Step 0 — Obtain Reference Script - -- Accept `$ARGUMENTS` or ask user for reference -- Analyze: detect input types (DataFrame? DataArray?), output types, whether - stateful (needs init data) or stateless, dependency packages, model architecture -- Determine if `@torch.inference_mode()` is safe (no gradient flow needed) -- Present summary to user - -No NN/physics branching — all DA models get the same flow. - -**[CONFIRM — Model Analysis]** - -#### Step 1 — Examine Reference and Propose Dependencies - -- Identify Python packages (physicsnemo, scipy, healpy, cudf, cupy, etc.) -- Propose pyproject.toml dependency group named `da-<model-name>` -- Highlight cudf/cupy as optional GPU acceleration packages - -**[CONFIRM — Dependencies]** - -#### Step 2 — Add Dependencies to pyproject.toml - -#### Step 3 — Create Skeleton Class File - -Dual inheritance: `torch.nn.Module + AutoModelMixin` (NO PrognosticMixin). -Class-level `@check_optional_dependencies()` decorator. - -**Canonical DA method ordering:** - -1. `__init__` — register `device_buffer`, store model params, normalize tolerance -2. `device` property — `return self.device_buffer.device` -3. `init_coords` — `None` for stateless, tuple for stateful -4. `input_coords` — tuple of `FrameSchema` (DF) or `CoordSystem` (DA) -5. `output_coords` — accept `input_coords` tuple + `request_time`, return tuple -6. `load_default_package` — classmethod -7. `load_model` — classmethod with `@check_optional_dependencies()` -8. `to` — device management, return `AssimilationModel` -9. Private/support methods (e.g., `_interpolate`, `_forward`, spatial lookups) -10. `__call__` — stateless forward, accept `*args: pd.DataFrame | xr.DataArray | None` -11. `create_generator` — bidirectional generator with send protocol - -**DA-specific skeleton elements:** - -- `FrameSchema` for observation inputs (time, lat, lon, observation, variable) -- `CoordSystem` for gridded outputs -- `request_time` from `obs.attrs`, NOT a coordinate dimension -- `validate_observation_fields()` call in `__call__` -- `filter_time_range()` for time-window filtering -- `dfseries_to_torch()` for zero-copy DataFrame→tensor -- cupy/cudf support: `cp.asarray()` for GPU output, `.cpu().numpy()` for CPU -- Generator: `yield None` to prime, `observations = yield result`, handle `GeneratorExit` -- `@torch.inference_mode()` unless gradient flow is required (document reason if omitted) -- No `@batch_func()`, no `@batch_coords()`, no `PrognosticMixin` - -**[CONFIRM — Skeleton]** - -#### Step 4 — Implement Coordinate System - -Key differences from px/dx: - -- `init_coords()` returns `None` for stateless models or tuple for stateful -- `input_coords()` returns **tuple** of `FrameSchema` — one per `__call__` arg -- `FrameSchema` keys are DF column names (time, lat, lon, observation, variable) -- `output_coords()` accepts tuple + `request_time`, returns tuple of `CoordSystem` -- Use `handshake_dim`/`handshake_coords`/`handshake_size` only for `CoordSystem` -- Use `validate_observation_fields()` for `FrameSchema` inputs - -**[CONFIRM — Coordinates]** - -#### Step 5 — Implement Forward Pass - -Two methods: - -- `__call__`: Extract `request_time` from `obs.attrs`, validate with - `validate_observation_fields()`, filter with `filter_time_range()`, convert with - `dfseries_to_torch()`, run model, build `xr.DataArray` output with cupy/numpy -- `create_generator`: Prime with `yield None` (stateless) or `yield initial_state` - (stateful), loop with `observations = yield result`, handle `GeneratorExit` - -**[CONFIRM — Forward Pass]** - -**Step 6 — Implement Model Loading** -`load_default_package`, `load_model` with `@check_optional_dependencies()` - -**[CONFIRM — Model Loading]** - -**Step 7 — Register in `__init__.py`** -Add to `earth2studio/models/da/__init__.py` - -**Step 8 — Verify Style/Format/Lint** -`make format`, `make lint`, `make license` - -The create skill ends here. All testing, validation, comparison, and PR work is -handled by the `validate-assimilation-wrapper` skill. - -**Reminders section** — DA-specific DO/DON'T: - -- DO return tuples from `input_coords` and `output_coords` -- DO use `FrameSchema` for tabular inputs, `CoordSystem` for gridded outputs -- DO validate `request_time` from `obs.attrs` -- DO use `validate_observation_fields()`, `filter_time_range()`, `dfseries_to_torch()` -- DO prime generator with `yield None` and handle `GeneratorExit` -- DO return cupy arrays on GPU, numpy on CPU -- DO register `device_buffer` and expose `device` property -- DO NOT use `@batch_func()` or `@batch_coords()` -- DO NOT use `PrognosticMixin` -- DO NOT use `create_iterator` — DA uses `create_generator` with send protocol -- DO NOT assume tensor inputs — inputs are DataFrames/DataArrays -- DO NOT forget cudf/cupy optional import pattern - ---- - -## Skill 2: `validate-assimilation-wrapper` - -### Frontmatter (validate) - -```yaml ---- -name: validate-assimilation-wrapper -description: >- - Validate a newly created Earth2Studio data assimilation model wrapper by - writing unit tests (90% coverage required), performing reference comparison - with DataFrame/DataArray outputs, generating sanity-check plots, and - opening a PR with automated code review. Use after completing - create-assimilation-wrapper Steps 0-8. -argument-hint: Name of the DA model class and test file ---- -``` - -### Steps 1-6 - -#### Step 1 — Write Pytest Unit Tests - -DA-specific test patterns: - -- PhooModelName dummy that returns known output shapes -- **Parametrize over pandas AND cudf** (skip cudf if unavailable) -- **Parametrize over CPU AND GPU** devices -- Test `__call__` with DataFrame input, verify `xr.DataArray` output -- Test generator protocol: prime → send → close -- Test `init_coords` returns correct type (None or tuple) -- Test time tolerance filtering -- Test empty DataFrame handling -- Test invalid `obs.attrs` (missing `request_time`) -- Test `validate_observation_fields` raises on bad columns -- `@pytest.mark.package` integration test - -**[CONFIRM — Package Test]** (for `@pytest.mark.package` integration test only) - -#### Step 2 — Run Tests and Achieve 90% Coverage - -- Run test file with `-v --timeout=60`, all must pass -- Coverage with `--slow`, `--cov`, **`--cov-fail-under=90`** — the new DA model - file must achieve **at least 90% line coverage** -- DA-specific coverage gaps to watch: generator cleanup path (`GeneratorExit`), - cudf code paths, empty DataFrame handling, time tolerance edge cases, - `obs.attrs` validation branches, cupy vs numpy output paths -- If coverage is below 90%, add tests to cover missing lines and re-run - -#### Step 3 — Reference Comparison and Sanity-Check - -3a. Reference comparison — compare `__call__` output AND multi-step generator -output against reference implementation. For DataArray output: compute tolerance -metrics (max abs diff, correlation, allclose). For DataFrame output: compare row -counts, value ranges, spatial coverage. - -3b. Model summary table — input schema, output grid, variables, stateful/stateless, -observation types, cudf support. - -3c. Three sanity-check plot templates: - -1. **Spatial assimilated output** — contourf of gridded DataArray output (like dx - spatial plot) -2. **Observation overlay** — scatter of input DataFrame observations overlaid on - assimilated grid output (unique to DA — shows sparse→dense mapping) -3. **Generator sequence** — multi-step assimilation evolution over time (for - stateful models) or repeated independent calls (for stateless) - -3d. Side-by-side comparison scripts — reference inference vs Earth2Studio equivalent - -3e-3f. Run scripts, user confirms plots. - -**[CONFIRM — Sanity-Check & Comparison]** - -#### Step 4 — Branch, Commit and Open PR - -DA-specific PR template fields: - -- Model type: Stateless / Stateful -- Input format: DataFrame / DataArray / Mixed -- Output format: DataArray / DataFrame -- Observation schema: columns and types -- Grid specification: lat-lon / HRRR / HealPix / etc. -- Time tolerance: default value -- cudf/cupy support: Yes / No -- Reference comparison metrics table - -**[CONFIRM — Ready to Submit]** - -**Step 5 — Automated Code Review (Greptile)** -Same polling/triage/fix pattern as dx/px. - -**[CONFIRM — Review Triage]** - -**Reminders** — same DA-specific rules as create skill. - ---- - -## Reference Files - -| File | Purpose | -| --------------------------------------------------- | ------------------------------------------ | -| `.cursor/rules/e2s-013-assimilation-models.mdc` | Authoritative DA rules (577 lines) | -| `earth2studio/models/da/base.py` | AssimilationModel Protocol definition | -| `earth2studio/models/da/interp.py` | Simple stateless DA example (593 lines) | -| `earth2studio/models/da/sda_stormcast.py` | Complex stateful DA example (919 lines) | -| `earth2studio/models/da/utils.py` | DA utilities (validate, filter, convert) | -| `test/models/da/test_da_interp.py` | DA test patterns (359 lines) | -| `test/models/da/test_da_healda.py` | Complex DA test patterns with mocks | - -## CONFIRM Gates Summary - -### Create skill (6 gates) - -1. `[CONFIRM — Model Analysis]` -2. `[CONFIRM — Dependencies]` -3. `[CONFIRM — Skeleton]` -4. `[CONFIRM — Coordinates]` -5. `[CONFIRM — Forward Pass]` -6. `[CONFIRM — Model Loading]` - -### Validate skill (4 gates) - -1. `[CONFIRM — Package Test]` -2. `[CONFIRM — Sanity-Check & Comparison]` -3. `[CONFIRM — Ready to Submit]` -4. `[CONFIRM — Review Triage]` From c9c8c208338be9ca43d5c304346c6453ccd3c758 Mon Sep 17 00:00:00 2001 From: Nicholas Geneva <ngeneva@nvidia.com> Date: Mon, 13 Apr 2026 23:10:19 +0000 Subject: [PATCH 6/7] Adding docs --- docs/userguide/developer/agents.md | 118 +++++++++++++++++++++++++++++ docs/userguide/developer/index.md | 1 + docs/userguide/index.md | 1 + 3 files changed, 120 insertions(+) create mode 100644 docs/userguide/developer/agents.md diff --git a/docs/userguide/developer/agents.md b/docs/userguide/developer/agents.md new file mode 100644 index 000000000..59b7a5889 --- /dev/null +++ b/docs/userguide/developer/agents.md @@ -0,0 +1,118 @@ +<!-- markdownlint-disable MD025 --> + +(developer_agents)= + +# AI Agent Skills + +Earth2Studio ships a set of **agent skills** that guide AI coding assistants through +common development workflows step by step. +Each skill is a structured Markdown document that lives in `.claude/skills/` and can be +invoked by compatible agents (Claude Code, OpenCode, Cursor, etc.) to scaffold new +components, run validation, and open pull requests. + +:::{note} +Agent skills are not a substitute for understanding the codebase. +They encode the project's conventions and review expectations so that an AI assistant +can follow them reliably, but a human developer should still review the output. +::: + +## Available Skills + +The skills are organised in **create / validate** pairs. +The *create* skill walks through implementation from a reference script to a working +wrapper; the *validate* skill picks up where *create* leaves off and handles testing, +comparison, and PR submission. + +| Skill | Purpose | +| ----- | ------- | +| `create-data-source` | Wrap a remote data store as a data source | +| `create-prognostic-wrapper` | Wrap a prognostic (time-stepping) model as a PrognosticModel | +| `validate-prognostic-wrapper` | Test, compare, and submit a new prognostic wrapper | +| `create-diagnostic-wrapper` | Wrap a diagnostic (single-step) model as a DiagnosticModel | +| `validate-diagnostic-wrapper` | Test, compare, and submit a new diagnostic wrapper | +| `create-assimilation-wrapper` | Wrap a data assimilation model as an AssimilationModel | +| `validate-assimilation-wrapper` | Test, compare, and submit a new assimilation wrapper | + +## How Skills Work + +Each skill document is a numbered step-by-step guide with **confirmation gates**. +At every gate the agent pauses and asks the developer to review what has been produced +before continuing. + +A typical *create* skill follows this flow: + +1. **Obtain a reference** -- the developer provides a URL or path to existing inference + code. +2. **Analyse dependencies** -- the agent reads the reference, proposes a + `pyproject.toml` extras group, and waits for confirmation. +3. **Create a skeleton** -- a class file is generated with the correct inheritance, + method ordering, and pseudocode bodies. +4. **Implement coordinates, forward pass, model loading** -- each section has its own + confirmation gate. +5. **Register, format, lint** -- the wrapper is added to `__init__.py` and checked with + `make format && make lint`. + +A typical *validate* skill continues with: + +1. **Run tests** -- unit tests are executed and coverage must reach 90 %. +2. **Reference comparison** -- the agent produces a side-by-side script comparing the + new wrapper against the original inference code and generates sanity-check plots. +3. **Open a PR** -- the agent creates a branch, commits, and opens a pull request to + `NVIDIA/earth2studio` with a standardised template. +4. **Greptile review** -- the agent polls for the automated code review, triages + feedback, and applies accepted fixes. + +## Using a Skill + +### Claude Code / OpenCode + +Invoke a skill directly by name or let the agent auto-detect it from context: + +```text +> /create-prognostic-wrapper https://github.com/example/model/infer.py +``` + +The agent will load `.claude/skills/create-prognostic-wrapper/SKILL.md` and begin +Step 0. +You can also describe what you want in plain language and the agent will select the +appropriate skill: + +```text +> I want to add the FourCastNet model as a new prognostic wrapper. +``` + +### Cursor + +The skills are referenced by the corresponding `.cursor/rules/` rule files. +When working inside a file that matches a rule's glob pattern (e.g., +`earth2studio/models/px/*.py`), Cursor will surface the relevant conventions +automatically. + +## Relationship to Cursor Rules + +The `.cursor/rules/` directory contains concise rule files that enforce coding +conventions (method ordering, decorators, coordinate validation, etc.). +Agent skills **build on** these rules: they reference the same conventions but add the +full workflow -- dependency management, skeleton creation, testing, PR submission, and +automated review handling. + +| Resource | Location | Scope | +| -------- | -------- | ----- | +| Cursor rules | `.cursor/rules/e2s-*.mdc` | Conventions and style enforcement | +| Agent skills | `.claude/skills/*/SKILL.md` | End-to-end development workflows | +| Slash commands | `.claude/commands/*.md` | One-shot tasks (format, lint, test, docs) | +| CLAUDE.md | `CLAUDE.md` | Quick-reference entry point for agents | + +## Writing a New Skill + +If you need to create a skill for a new component type: + +1. Create a directory under `.claude/skills/<skill-name>/`. +2. Add a `SKILL.md` file with YAML frontmatter (`name`, `description`, + `argument-hint`). +3. Structure the document as numbered steps with `[CONFIRM -- <Title>]` gates at + decision points. +4. Reference the matching `.cursor/rules/` file for coding conventions. +5. Include a **Reminders** section at the end with DO / DO NOT rules. + +Look at an existing skill (e.g., `create-prognostic-wrapper`) as a template. diff --git a/docs/userguide/developer/index.md b/docs/userguide/developer/index.md index 3befd4d92..cd429b857 100644 --- a/docs/userguide/developer/index.md +++ b/docs/userguide/developer/index.md @@ -10,4 +10,5 @@ documentation testing build recipes +agents ``` diff --git a/docs/userguide/index.md b/docs/userguide/index.md index 6dbf74309..887a72d82 100644 --- a/docs/userguide/index.md +++ b/docs/userguide/index.md @@ -68,6 +68,7 @@ run(["2024-01-01"], 10, model, ds, io) - [Testing](developer/testing) - [Build](developer/build) - [Recipes](developer/recipes) +- [AI Agent Skills](developer/agents) ## Support From 4cf46846422a97b6da5be50ef0e7311fa7675fcf Mon Sep 17 00:00:00 2001 From: Nicholas Geneva <ngeneva@nvidia.com> Date: Tue, 14 Apr 2026 16:24:04 +0000 Subject: [PATCH 7/7] Greptile --- .claude/skills/validate-assimilation-wrapper/SKILL.md | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/.claude/skills/validate-assimilation-wrapper/SKILL.md b/.claude/skills/validate-assimilation-wrapper/SKILL.md index 45f785ced..453430d68 100644 --- a/.claude/skills/validate-assimilation-wrapper/SKILL.md +++ b/.claude/skills/validate-assimilation-wrapper/SKILL.md @@ -234,9 +234,8 @@ def test_model_call(sample_observations_pandas, device): assert da.coords["time"].values[0] == request_time[0] # Validate output shape matches model's coordinate system - n_variables = len(model.VARIABLES) if hasattr(model, "VARIABLES") else da.shape[1] + # NOTE: n_variables varies by model - update this check to match the model's output assert da.shape[0] == len(request_time) - assert da.shape[1] == n_variables # Validate coordinate values assert "t2m" in da.coords["variable"].values # At least one expected variable @@ -589,14 +588,14 @@ least 90% line coverage**: ```bash uv run python -m pytest test/models/da/test_<filename>.py -v \ - --slow --timeout=300 \ + --package --timeout=300 \ --cov=earth2studio/models/da/<filename> \ --cov-report=term-missing \ --cov-fail-under=90 ``` -- `--slow` enables integration tests marked with `@pytest.mark.package` - (the `--slow` flag is configured in `conftest.py` to include package +- `--package` enables integration tests marked with `@pytest.mark.package` + (the `--package` flag is configured in `conftest.py` to include package tests that download real checkpoints and may require GPU) - `--cov=earth2studio/models/da/<filename>` scopes coverage to the new model module only