diff --git a/earth2studio/run.py b/earth2studio/run.py index 706acaa93..a93338b9b 100644 --- a/earth2studio/run.py +++ b/earth2studio/run.py @@ -45,6 +45,7 @@ def deterministic( output_coords: CoordSystem = OrderedDict({}), device: torch.device | None = None, verbose: bool = True, + compile: bool = False, ) -> IOBackend: """Built in deterministic workflow. This workflow creates a determinstic inference pipeline to produce a forecast @@ -68,6 +69,10 @@ def deterministic( Device to run inference on, by default None verbose : bool, optional Print inference progress, by default True + compile : bool, optional + Use torch.compile to accelerate the prognostic model's forward pass, + by default False. Note: This requires PyTorch >= 2.0 and may incur + initial compilation overhead. Returns ------- @@ -84,6 +89,17 @@ def deterministic( ) logger.info(f"Inference device: {device}") prognostic = prognostic.to(device) + + if compile: + if hasattr(prognostic, "_forward"): + logger.info("Compiling prognostic model...") + prognostic._forward = torch.compile( + prognostic._forward, mode="reduce-overhead" + ) + else: + logger.warning( + "Compilation requested but prognostic model does not have _forward method. Skipping." + ) # sphinx - fetch data start # Fetch data from data source and load onto device prognostic_ic = prognostic.input_coords() @@ -163,6 +179,7 @@ def diagnostic( output_coords: CoordSystem = OrderedDict({}), device: torch.device | None = None, verbose: bool = True, + compile: bool = False, ) -> IOBackend: """Built in diagnostic workflow. This workflow creates a determinstic inference pipeline that couples a prognostic @@ -188,6 +205,10 @@ def diagnostic( Device to run inference on, by default None verbose : bool, optional Print inference progress, by default True + compile : bool, optional + Use torch.compile to accelerate the prognostic and diagnostic model's + forward pass, by default False. Note: This requires PyTorch >= 2.0 and + may incur initial compilation overhead. Returns ------- @@ -205,6 +226,21 @@ def diagnostic( logger.info(f"Inference device: {device}") prognostic = prognostic.to(device) diagnostic = diagnostic.to(device) + + if compile: + if hasattr(prognostic, "_forward"): + logger.info("Compiling prognostic model...") + prognostic._forward = torch.compile( + prognostic._forward, mode="reduce-overhead" + ) + else: + logger.warning( + "Compilation requested but prognostic model does not have _forward method. Skipping prognostic compilation." + ) + + # Compile the diagnostic model call + logger.info("Compiling diagnostic model...") + diagnostic = torch.compile(diagnostic, mode="reduce-overhead") # Fetch data from data source and load onto device prognostic_ic = prognostic.input_coords() diagnostic_ic = diagnostic.input_coords() @@ -290,6 +326,7 @@ def ensemble( output_coords: CoordSystem = OrderedDict({}), device: torch.device | None = None, verbose: bool = True, + compile: bool = False, ) -> IOBackend: """Built in ensemble workflow. @@ -318,6 +355,10 @@ def ensemble( Device to run inference on, by default None verbose : bool, optional Print inference progress, by default True + compile : bool, optional + Use torch.compile to accelerate the prognostic model's forward pass, + by default False. Note: This requires PyTorch >= 2.0 and may incur + initial compilation overhead. Returns ------- @@ -336,6 +377,17 @@ def ensemble( logger.info(f"Inference device: {device}") prognostic = prognostic.to(device) + if compile: + if hasattr(prognostic, "_forward"): + logger.info("Compiling prognostic model...") + prognostic._forward = torch.compile( + prognostic._forward, mode="reduce-overhead" + ) + else: + logger.warning( + "Compilation requested but prognostic model does not have _forward method. Skipping." + ) + # Fetch data from data source and load onto device prognostic_ic = prognostic.input_coords() time = to_time_array(time) diff --git a/test/run/test_compile.py b/test/run/test_compile.py new file mode 100644 index 000000000..6558ff686 --- /dev/null +++ b/test/run/test_compile.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +import torch + +import earth2studio.run as run +from earth2studio.data import Random +from earth2studio.io import ZarrBackend +from earth2studio.models.px import Persistence + + +def test_deterministic_compile_flag(): + """Verify that the compile flag triggers torch.compile on the prognostic model's _forward method.""" + coords = OrderedDict([("lat", np.arange(10)), ("lon", np.arange(20))]) + variable = ["t2m"] + nsteps = 1 + time = ["2024-01-01"] + device = "cpu" + + data = Random(domain_coords=coords) + model = Persistence(variable, coords) + io = ZarrBackend() + + # Mock torch.compile to avoid actual compilation during test + with patch("torch.compile", side_effect=lambda x, **kwargs: x) as mock_compile: + run.deterministic(time, nsteps, model, data, io, device=device, compile=True) + + # Verify torch.compile was called + # Note: We check if it was called at least once. + # In deterministic, it's called on prognostic._forward + assert mock_compile.called + # Check if it was called with the model's _forward method + # The first argument to the first call should be a function (the _forward method) + args, kwargs = mock_compile.call_args + assert kwargs.get("mode") == "reduce-overhead" + + +def test_diagnostic_compile_flag(): + """Verify that the compile flag triggers torch.compile on both models in diagnostic workflow.""" + coords = OrderedDict([("lat", np.arange(10)), ("lon", np.arange(20))]) + variable = ["t2m"] + nsteps = 1 + time = ["2024-01-01"] + device = "cpu" + + data = Random(domain_coords=coords) + prognostic = Persistence(variable, coords) + + # Simple diagnostic model mock + diagnostic_model = MagicMock() + diagnostic_model.to.return_value = diagnostic_model + diagnostic_model.input_coords.return_value = prognostic.input_coords() + diagnostic_model.output_coords.return_value = prognostic.output_coords(prognostic.input_coords()) + + io = ZarrBackend() + + with patch("torch.compile", side_effect=lambda x, **kwargs: x) as mock_compile: + run.diagnostic(time, nsteps, prognostic, diagnostic_model, data, io, device=device, compile=True) + + # Should be called for prognostic._forward and diagnostic_model + assert mock_compile.call_count >= 2 + + +def test_ensemble_compile_flag(): + """Verify that the compile flag triggers torch.compile on the prognostic model in ensemble workflow.""" + coords = OrderedDict([("lat", np.arange(10)), ("lon", np.arange(20))]) + variable = ["t2m"] + nsteps = 1 + nensemble = 2 + time = ["2024-01-01"] + device = "cpu" + + data = Random(domain_coords=coords) + model = Persistence(variable, coords) + + perturbation = MagicMock() + perturbation.side_effect = lambda x, c: (x, c) + + io = ZarrBackend() + + with patch("torch.compile", side_effect=lambda x, **kwargs: x) as mock_compile: + run.ensemble(time, nsteps, nensemble, model, data, io, perturbation, device=device, compile=True) + + assert mock_compile.called + args, kwargs = mock_compile.call_args + assert kwargs.get("mode") == "reduce-overhead"