Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions earth2studio/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def deterministic(
output_coords: CoordSystem = OrderedDict({}),
device: torch.device | None = None,
verbose: bool = True,
compile: bool = False,
) -> IOBackend:
"""Built in deterministic workflow.
This workflow creates a determinstic inference pipeline to produce a forecast
Expand All @@ -68,6 +69,10 @@ def deterministic(
Device to run inference on, by default None
verbose : bool, optional
Print inference progress, by default True
compile : bool, optional
Use torch.compile to accelerate the prognostic model's forward pass,
by default False. Note: This requires PyTorch >= 2.0 and may incur
initial compilation overhead.

Returns
-------
Expand All @@ -84,6 +89,17 @@ def deterministic(
)
logger.info(f"Inference device: {device}")
prognostic = prognostic.to(device)

if compile:
if hasattr(prognostic, "_forward"):
logger.info("Compiling prognostic model...")
prognostic._forward = torch.compile(
prognostic._forward, mode="reduce-overhead"
)
else:
logger.warning(
"Compilation requested but prognostic model does not have _forward method. Skipping."
)
# sphinx - fetch data start
# Fetch data from data source and load onto device
prognostic_ic = prognostic.input_coords()
Expand Down Expand Up @@ -163,6 +179,7 @@ def diagnostic(
output_coords: CoordSystem = OrderedDict({}),
device: torch.device | None = None,
verbose: bool = True,
compile: bool = False,
) -> IOBackend:
"""Built in diagnostic workflow.
This workflow creates a determinstic inference pipeline that couples a prognostic
Expand All @@ -188,6 +205,10 @@ def diagnostic(
Device to run inference on, by default None
verbose : bool, optional
Print inference progress, by default True
compile : bool, optional
Use torch.compile to accelerate the prognostic and diagnostic model's
forward pass, by default False. Note: This requires PyTorch >= 2.0 and
may incur initial compilation overhead.

Returns
-------
Expand All @@ -205,6 +226,21 @@ def diagnostic(
logger.info(f"Inference device: {device}")
prognostic = prognostic.to(device)
diagnostic = diagnostic.to(device)

if compile:
if hasattr(prognostic, "_forward"):
logger.info("Compiling prognostic model...")
prognostic._forward = torch.compile(
prognostic._forward, mode="reduce-overhead"
)
else:
logger.warning(
"Compilation requested but prognostic model does not have _forward method. Skipping prognostic compilation."
)

# Compile the diagnostic model call
logger.info("Compiling diagnostic model...")
diagnostic = torch.compile(diagnostic, mode="reduce-overhead")
# Fetch data from data source and load onto device
prognostic_ic = prognostic.input_coords()
diagnostic_ic = diagnostic.input_coords()
Expand Down Expand Up @@ -290,6 +326,7 @@ def ensemble(
output_coords: CoordSystem = OrderedDict({}),
device: torch.device | None = None,
verbose: bool = True,
compile: bool = False,
) -> IOBackend:
"""Built in ensemble workflow.

Expand Down Expand Up @@ -318,6 +355,10 @@ def ensemble(
Device to run inference on, by default None
verbose : bool, optional
Print inference progress, by default True
compile : bool, optional
Use torch.compile to accelerate the prognostic model's forward pass,
by default False. Note: This requires PyTorch >= 2.0 and may incur
initial compilation overhead.

Returns
-------
Expand All @@ -336,6 +377,17 @@ def ensemble(
logger.info(f"Inference device: {device}")
prognostic = prognostic.to(device)

if compile:
if hasattr(prognostic, "_forward"):
logger.info("Compiling prognostic model...")
prognostic._forward = torch.compile(
prognostic._forward, mode="reduce-overhead"
)
else:
logger.warning(
"Compilation requested but prognostic model does not have _forward method. Skipping."
)

# Fetch data from data source and load onto device
prognostic_ic = prognostic.input_coords()
time = to_time_array(time)
Expand Down
104 changes: 104 additions & 0 deletions test/run/test_compile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
import torch

import earth2studio.run as run
from earth2studio.data import Random
from earth2studio.io import ZarrBackend
from earth2studio.models.px import Persistence


def test_deterministic_compile_flag():
"""Verify that the compile flag triggers torch.compile on the prognostic model's _forward method."""
coords = OrderedDict([("lat", np.arange(10)), ("lon", np.arange(20))])
variable = ["t2m"]
nsteps = 1
time = ["2024-01-01"]
device = "cpu"

data = Random(domain_coords=coords)
model = Persistence(variable, coords)
io = ZarrBackend()

# Mock torch.compile to avoid actual compilation during test
with patch("torch.compile", side_effect=lambda x, **kwargs: x) as mock_compile:
run.deterministic(time, nsteps, model, data, io, device=device, compile=True)

# Verify torch.compile was called
# Note: We check if it was called at least once.
# In deterministic, it's called on prognostic._forward
assert mock_compile.called
# Check if it was called with the model's _forward method
# The first argument to the first call should be a function (the _forward method)
args, kwargs = mock_compile.call_args
assert kwargs.get("mode") == "reduce-overhead"


def test_diagnostic_compile_flag():
"""Verify that the compile flag triggers torch.compile on both models in diagnostic workflow."""
coords = OrderedDict([("lat", np.arange(10)), ("lon", np.arange(20))])
variable = ["t2m"]
nsteps = 1
time = ["2024-01-01"]
device = "cpu"

data = Random(domain_coords=coords)
prognostic = Persistence(variable, coords)

# Simple diagnostic model mock
diagnostic_model = MagicMock()
diagnostic_model.to.return_value = diagnostic_model
diagnostic_model.input_coords.return_value = prognostic.input_coords()
diagnostic_model.output_coords.return_value = prognostic.output_coords(prognostic.input_coords())

io = ZarrBackend()

with patch("torch.compile", side_effect=lambda x, **kwargs: x) as mock_compile:
run.diagnostic(time, nsteps, prognostic, diagnostic_model, data, io, device=device, compile=True)

# Should be called for prognostic._forward and diagnostic_model
assert mock_compile.call_count >= 2


def test_ensemble_compile_flag():
"""Verify that the compile flag triggers torch.compile on the prognostic model in ensemble workflow."""
coords = OrderedDict([("lat", np.arange(10)), ("lon", np.arange(20))])
variable = ["t2m"]
nsteps = 1
nensemble = 2
time = ["2024-01-01"]
device = "cpu"

data = Random(domain_coords=coords)
model = Persistence(variable, coords)

perturbation = MagicMock()
perturbation.side_effect = lambda x, c: (x, c)

io = ZarrBackend()

with patch("torch.compile", side_effect=lambda x, **kwargs: x) as mock_compile:
run.ensemble(time, nsteps, nensemble, model, data, io, perturbation, device=device, compile=True)

assert mock_compile.called
args, kwargs = mock_compile.call_args
assert kwargs.get("mode") == "reduce-overhead"