🚀[FEA]: Add torch.compile support for inference workflows #812
manmeet3591 wants to merge 1 commit into
This commit introduces an optional 'compile' parameter to the deterministic, diagnostic, and ensemble workflows in run.py. When enabled, it wraps the prognostic model's _forward method (and the diagnostic model if applicable) with torch.compile using 'reduce-overhead' mode. This can significantly speed up autoregressive rollouts by utilizing CUDA Graphs.

Signed-off-by: Manmeet Singh <manmeet20singh11@gmail.com>
Greptile Summary
This PR adds an optional

| Filename | Overview |
|---|---|
| earth2studio/run.py | Adds optional compile=True flag to all three workflows; the hasattr(_forward) guard is insufficient for ONNX Runtime-backed models (Pangu, FuXi, FengWu), and the diagnostic workflow replaces the entire model object with the compiled wrapper, which breaks input_coords/output_coords access for non-nn.Module diagnostics. |
| test/run/test_compile.py | New test file verifying torch.compile is triggered with the right mode; only covers the Persistence model (compile-compatible), missing coverage for the no-_forward warning path and non-PyTorch backend models. |
Reviews (1): Last reviewed commit: "🚀[FEA]: Add torch.compile support for i..."

    if hasattr(prognostic, "_forward"):
        logger.info("Compiling prognostic model...")
        prognostic._forward = torch.compile(
            prognostic._forward, mode="reduce-overhead"
        )

hasattr guard is insufficient for ONNX Runtime-backed models
Pangu, FuXi, and FengWu all define _forward (and therefore pass this check), but their implementations call onnxruntime.InferenceSession.io_binding() and run_with_iobinding() — operations that torch._dynamo cannot trace. With mode="reduce-overhead" (CUDA Graphs), these opaque Python calls will either trigger torch._dynamo.exc.Unsupported at the first inference step, or cause so many graph breaks that CUDA Graphs are never captured and the compilation adds pure overhead with no benefit. The compiled wrapper is silently stored on the instance and the failure only surfaces when the generator yields its first real step.
A minimal guard would check for a marker attribute (e.g., torch_compilable) or check that the model is not ORT-based before attempting to compile. The same issue applies in the diagnostic and ensemble workflows.
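The guard suggested in this comment could be sketched as follows. This is a minimal sketch under stated assumptions, not the PR's implementation: torch_compilable is the hypothetical marker name floated above, maybe_compile_forward is a helper invented for illustration, and compile_fn is injected so the example runs without PyTorch installed.

```python
import logging
from typing import Any, Callable

logger = logging.getLogger(__name__)


def maybe_compile_forward(
    model: Any,
    compile_fn: Callable[[Callable[..., Any]], Callable[..., Any]],
) -> bool:
    """Wrap model._forward with compile_fn only when the model opts in.

    Returns True if the method was wrapped, False if compilation was skipped.
    """
    if not hasattr(model, "_forward"):
        logger.warning("Model has no _forward method; skipping compilation.")
        return False
    # Hypothetical opt-in marker: models backed by ONNX Runtime would leave
    # it unset (or False) so they are never handed to the compiler.
    if not getattr(model, "torch_compilable", False):
        logger.warning("Model is not marked torch_compilable; skipping compilation.")
        return False
    # Instance-level patch, the same approach the PR uses for the prognostic.
    model._forward = compile_fn(model._forward)
    return True
```

In run.py the call might look like maybe_compile_forward(prognostic, functools.partial(torch.compile, mode="reduce-overhead")). An opt-in marker is the conservative choice here: a model that never declares itself compilable is skipped with a warning instead of failing at the first rollout step.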

    output_coords: CoordSystem = OrderedDict({}),
    device: torch.device | None = None,
    verbose: bool = True,
    compile: bool = False,

Parameter name shadows Python built-in compile
The parameter name compile shadows the Python built-in of the same name. While compile isn't used inside these function bodies, shadowing built-ins makes code harder to read and can confuse linters. Consider renaming to use_compile or torch_compile. The same applies to the diagnostic and ensemble signatures.

    -    compile: bool = False,
    +    use_compile: bool = False,


    if compile:
        if hasattr(prognostic, "_forward"):
            logger.info("Compiling prognostic model...")
            prognostic._forward = torch.compile(
                prognostic._forward, mode="reduce-overhead"
            )
        else:
            logger.warning(
                "Compilation requested but prognostic model does not have _forward method. Skipping prognostic compilation."
            )

        # Compile the diagnostic model call
        logger.info("Compiling diagnostic model...")
        diagnostic = torch.compile(diagnostic, mode="reduce-overhead")

Asymmetric compilation strategy breaks output_coords access on the compiled diagnostic
The prognostic is compiled at the _forward method level (instance attribute patch), while the diagnostic is replaced wholesale with torch.compile(diagnostic, ...). After this line, diagnostic.input_coords() and diagnostic.output_coords(diagnostic_ic) are called on the returned OptimizedModule. This works for nn.Module subclasses because OptimizedModule.__getattr__ proxies to the original module, but the DiagnosticModel protocol does not require nn.Module inheritance. A non-nn.Module diagnostic (e.g., a plain callable that satisfies the protocol) would lose input_coords and output_coords after torch.compile wraps it, raising AttributeError at line 246.
For consistency and safety, compile the diagnostic's forward/__call__ method at the instance level (same approach as for the prognostic) rather than replacing the whole object.
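One way to read this suggestion is sketched below. It is an illustration under stated assumptions, not code from the PR: compile_in_place and the method_names default are invented here, and compile_fn stands in for functools.partial(torch.compile, mode="reduce-overhead") so the example runs without PyTorch. Note that Python looks up __call__ on the type rather than the instance, so assigning a compiled function to diagnostic.__call__ would have no effect on diagnostic(x, coords); the sketch therefore patches a named method that __call__ delegates to.

```python
from typing import Any, Callable


def compile_in_place(
    diagnostic: Any,
    compile_fn: Callable[[Callable[..., Any]], Callable[..., Any]],
    method_names: tuple[str, ...] = ("_forward", "forward"),
) -> Any:
    """Compile the first matching method on the instance; return the same object."""
    for name in method_names:
        method = getattr(diagnostic, name, None)
        if callable(method):
            # Instance-attribute patch: the object keeps its identity, so
            # input_coords() and output_coords() stay reachable afterwards.
            setattr(diagnostic, name, compile_fn(method))
            break
    return diagnostic
```

Because the original object is returned, later calls such as diagnostic.input_coords() behave the same whether or not the diagnostic inherits from nn.Module. A diagnostic that exposes neither method name is returned unchanged, which a real implementation would probably want to log.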

    device = "cpu"

    data = Random(domain_coords=coords)
    model = Persistence(variable, coords)
    io = ZarrBackend()

    # Mock torch.compile to avoid actual compilation during test
    with patch("torch.compile", side_effect=lambda x, **kwargs: x) as mock_compile:
        run.deterministic(time, nsteps, model, data, io, device=device, compile=True)

    # Verify torch.compile was called
    # Note: We check if it was called at least once.
    # In deterministic, it's called on prognostic._forward
    assert mock_compile.called

Test only covers the happy path with a compile-compatible model
All three tests use Persistence, which has a pure-PyTorch _forward. There is no test for:
- A model without _forward (verifying the logger.warning path fires and inference still completes).
- A model whose _forward uses non-PyTorch backends (the ORT-compatibility issue noted in run.py).
The compile=False default path (no call to torch.compile) is also not explicitly asserted.
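The missing cases could take roughly the following shape. This is a self-contained sketch, not a drop-in addition to test/run/test_compile.py: workflow is a stand-in that mirrors only the compile branch quoted in this review, and NoForward, WithForward, and the injected compile_fn are hypothetical names used so the example runs without earth2studio or PyTorch.

```python
import logging
from unittest.mock import MagicMock, patch

logger = logging.getLogger("run_sketch")


def workflow(model, compile=False, compile_fn=None):
    """Stand-in for the compile branch under review; not earth2studio code."""
    if compile:  # parameter name mirrors the PR signature
        if hasattr(model, "_forward"):
            model._forward = compile_fn(model._forward, mode="reduce-overhead")
        else:
            logger.warning("Compilation requested but model has no _forward.")
    return model(1)


class NoForward:
    def __call__(self, x):
        return x


class WithForward:
    def _forward(self, x):
        return x + 1

    def __call__(self, x):
        return self._forward(x)


def test_missing_forward_warns_and_still_runs():
    compile_fn = MagicMock()
    with patch.object(logger, "warning") as warn:
        assert workflow(NoForward(), compile=True, compile_fn=compile_fn) == 1
    warn.assert_called_once()
    compile_fn.assert_not_called()


def test_default_path_never_compiles():
    compile_fn = MagicMock()
    assert workflow(WithForward(), compile_fn=compile_fn) == 2
    compile_fn.assert_not_called()


def test_mode_is_forwarded():
    compile_fn = MagicMock(side_effect=lambda fn, **kwargs: fn)
    assert workflow(WithForward(), compile=True, compile_fn=compile_fn) == 2
    assert compile_fn.call_args.kwargs == {"mode": "reduce-overhead"}
```

Against the real workflows, the same assertions would presumably be made on the mock returned by patch("torch.compile", ...), as the existing test already does, with run.deterministic in place of the stand-in.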
This PR addresses Issue #721 by adding optional torch.compile support to the core inference workflows.

Problem
Inference workflows in earth2studio/run.py currently execute autoregressive rollouts using plain Python iteration. For models that repeat the exact same operations at every step (e.g., FCN3, SFNO), this results in significant Python overhead and redundant CUDA kernel launches.

Solution
An optional compile parameter has been added to the following workflows:
- deterministic()
- diagnostic()
- ensemble()

When compile=True, the prognostic model's _forward method is wrapped with torch.compile(mode="reduce-overhead"). In the diagnostic workflow, the diagnostic model is also compiled.

Key Benefits

Testing
Added test/run/test_compile.py to verify that the compile flag correctly triggers torch.compile with the expected parameters across all three workflows using mocks.