Skip to content

🚀[FEA]: Add torch.compile support for inference workflows#812

Open
manmeet3591 wants to merge 1 commit into
NVIDIA:mainfrom
manmeet3591:feature/torch-compile-support
Open

🚀[FEA]: Add torch.compile support for inference workflows#812
manmeet3591 wants to merge 1 commit into
NVIDIA:mainfrom
manmeet3591:feature/torch-compile-support

Conversation

@manmeet3591
Copy link
Copy Markdown

@manmeet3591 manmeet3591 commented Apr 13, 2026

This PR addresses Issue #721 by adding optional torch.compile support to the core inference workflows.

Problem

Inference workflows in earth2studio/run.py currently execute autoregressive rollouts using plain Python iteration. For models that repeat the exact same operations at every step (e.g., FCN3, SFNO), this results in significant Python overhead and redundant CUDA kernel launches.

Solution

An optional compile parameter has been added to the following workflows:

  • deterministic()
  • diagnostic()
  • ensemble()

When compile=True, the prognostic model's _forward method is wrapped with torch.compile(mode="reduce-overhead"). In the diagnostic workflow, the diagnostic model is also compiled.

Key Benefits

  • CUDA Graph Capture: Automatically utilizes CUDA Graphs to replay the computation graph, eliminating per-step launch overhead.
  • Speedup: Can yield 1.5–3x speedups for common prognostic models during long rollouts.
  • Minimal Changes: Opt-in feature that preserves existing eager mode behavior by default.

Testing

Added test/run/test_compile.py to verify that the compile flag correctly triggers torch.compile with the expected parameters across all three workflows using mocks.

This commit introduces an optional 'compile' parameter to the
deterministic, diagnostic, and ensemble workflows in run.py.
When enabled, it wraps the prognostic model's _forward method
(and the diagnostic model if applicable) with torch.compile
using 'reduce-overhead' mode. This can significantly speed up
autoregressive rollouts by utilizing CUDA Graphs.

Signed-off-by: Manmeet Singh <manmeet20singh11@gmail.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Apr 13, 2026

Greptile Summary

This PR adds an optional compile=True flag to the deterministic, diagnostic, and ensemble workflows in run.py, wrapping the prognostic model's _forward method (and the diagnostic module) with torch.compile(mode="reduce-overhead").

  • The hasattr(prognostic, \"_forward\") guard is insufficient: models like Pangu, FuXi, and FengWu define _forward but delegate to ONNX Runtime sessions, which torch._dynamo cannot trace. With mode=\"reduce-overhead\" (CUDA Graphs), this produces either a torch._dynamo.exc.Unsupported error at the first real inference step or silent graph-break overhead with no speedup.
  • In the diagnostic workflow, diagnostic = torch.compile(diagnostic, ...) replaces the model object entirely; input_coords() and output_coords() are called on the returned wrapper immediately afterward, which fails for any DiagnosticModel that is not a nn.Module.

Confidence Score: 3/5

Not safe to merge as-is: the compile flag will silently succeed but fail at inference time for several production models (Pangu, FuXi, FengWu).

Two P1 issues: the hasattr(_forward) guard passes for ORT-backed models causing runtime failures when compile=True is used with them, and the whole-object replacement of the diagnostic breaks the non-nn.Module contract. The feature works correctly for pure-PyTorch models (e.g., FCN, SFNO, Persistence), but the guard needs to be tightened before enabling this for general use.

earth2studio/run.py — the compile guard logic and the diagnostic model replacement strategy both need fixes before merging.

Important Files Changed

Filename Overview
earth2studio/run.py Adds optional compile=True flag to all three workflows; the hasattr(_forward) guard is insufficient for ONNX Runtime-backed models (Pangu, FuXi, FengWu), and the diagnostic workflow replaces the entire model object with the compiled wrapper, which breaks input_coords/output_coords access for non-nn.Module diagnostics.
test/run/test_compile.py New test file verifying torch.compile is triggered with the right mode; only covers the Persistence model (compile-compatible), missing coverage for the no-_forward warning path and non-PyTorch backend models.

Reviews (1): Last reviewed commit: "🚀[FEA]: Add torch.compile support for i..." | Re-trigger Greptile

Comment thread earth2studio/run.py
Comment on lines +94 to +98
if hasattr(prognostic, "_forward"):
logger.info("Compiling prognostic model...")
prognostic._forward = torch.compile(
prognostic._forward, mode="reduce-overhead"
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 hasattr guard is insufficient for ONNX Runtime-backed models

Pangu, FuXi, and FengWu all define _forward (and therefore pass this check), but their implementations call onnxruntime.InferenceSession.io_binding() and run_with_iobinding() — operations that torch._dynamo cannot trace. With mode="reduce-overhead" (CUDA Graphs), these opaque Python calls will either trigger torch._dynamo.exc.Unsupported at the first inference step, or cause so many graph breaks that CUDA Graphs are never captured and the compilation adds pure overhead with no benefit. The compiled wrapper is silently stored on the instance and the failure only surfaces when the generator yields its first real step.

A minimal guard would check for a marker attribute (e.g., torch_compilable) or check that the model is not ORT-based before attempting to compile. The same issue applies in the diagnostic and ensemble workflows.

Comment thread earth2studio/run.py
output_coords: CoordSystem = OrderedDict({}),
device: torch.device | None = None,
verbose: bool = True,
compile: bool = False,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Parameter name shadows Python built-in compile

The parameter name compile shadows the Python built-in of the same name. While compile isn't used inside these function bodies, shadowing built-ins makes code harder to read and can confuse linters. Consider renaming to use_compile or torch_compile. The same applies to the diagnostic and ensemble signatures.

Suggested change
compile: bool = False,
use_compile: bool = False,

Comment thread earth2studio/run.py
Comment on lines +230 to +243
if compile:
if hasattr(prognostic, "_forward"):
logger.info("Compiling prognostic model...")
prognostic._forward = torch.compile(
prognostic._forward, mode="reduce-overhead"
)
else:
logger.warning(
"Compilation requested but prognostic model does not have _forward method. Skipping prognostic compilation."
)

# Compile the diagnostic model call
logger.info("Compiling diagnostic model...")
diagnostic = torch.compile(diagnostic, mode="reduce-overhead")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Asymmetric compilation strategy breaks output_coords access on the compiled diagnostic

The prognostic is compiled at the _forward method level (instance attribute patch), while the diagnostic is replaced wholesale with torch.compile(diagnostic, ...). After this line, diagnostic.input_coords() and diagnostic.output_coords(diagnostic_ic) are called on the returned OptimizedModule. This works for nn.Module subclasses because OptimizedModule.__getattr__ proxies to the original module, but the DiagnosticModel protocol does not require nn.Module inheritance. A non-nn.Module diagnostic (e.g., a plain callable that satisfies the protocol) would lose input_coords and output_coords after torch.compile wraps it, raising AttributeError at line 246.

For consistency and safety, compile the diagnostic's forward/__call__ method at the instance level (same approach as for the prognostic) rather than replacing the whole object.

Comment thread test/run/test_compile.py
Comment on lines +36 to +49
device = "cpu"

data = Random(domain_coords=coords)
model = Persistence(variable, coords)
io = ZarrBackend()

# Mock torch.compile to avoid actual compilation during test
with patch("torch.compile", side_effect=lambda x, **kwargs: x) as mock_compile:
run.deterministic(time, nsteps, model, data, io, device=device, compile=True)

# Verify torch.compile was called
# Note: We check if it was called at least once.
# In deterministic, it's called on prognostic._forward
assert mock_compile.called
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Test only covers the happy path with a compile-compatible model

All three tests use Persistence, which has a pure-PyTorch _forward. There is no test for:

  • A model without _forward (verifying the logger.warning path fires and inference still completes).
  • A model whose _forward uses non-PyTorch backends (the ORT-compatibility issue noted in run.py).

The compile=False default path (no call to torch.compile) is also not explicitly asserted.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant