Collaborator

@tdophung tdophung commented Dec 23, 2025

Description

The pytorch-triton and triton packages both install to the same location, site-packages/triton. The stock triton package does not work with PyTorch's torch.compile(), because PyTorch has added a few things to its own build of Triton (published as pytorch-triton and validated against each torch release). However, pytorch-triton should in theory, and in my experiments, still be compatible with how JAX uses Triton*.
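Because both distributions land in the same site-packages/triton directory, `import triton` alone does not say which one is installed. A minimal sketch of one way to check, assuming the distributions are registered with pip under the names "triton" and "pytorch-triton" (the helper name `installed_triton_distributions` is hypothetical and is not part of this PR):

```python
from importlib import metadata


def installed_triton_distributions(candidates=("triton", "pytorch-triton")):
    """Return {distribution name: version} for each candidate that is installed.

    The candidate names are assumptions about how the packages are published.
    """
    found = {}
    for name in candidates:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # This distribution is not installed in the current environment.
            continue
    return found


if __name__ == "__main__":
    dists = installed_triton_distributions()
    if len(dists) > 1:
        print("Both distributions are installed and share site-packages/triton:", dists)
    else:
        print("Installed Triton distributions:", dists)
```

If both names show up, whichever package was installed last most likely owns the files on disk, so the reported versions may not match what `import triton` actually loads.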

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add a new env var to control when pytorch-triton is used in JAX
  • Switch PyTorch back to using/checking for pytorch-triton by default
  • Add documentation (comments) on this package conflict
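As a rough illustration of the first change, the requirement selection could look like the sketch below. The env var name NVTE_USE_PYTORCH_TRITON comes from the review summary further down in this conversation; the helper name `triton_requirement` and the rule that only the value "1" enables the option are assumptions, not the code in build_tools/jax.py.

```python
import os


def triton_requirement(env=None):
    """Pick the Triton distribution name to request for the JAX extensions.

    Hypothetical helper: treats NVTE_USE_PYTORCH_TRITON=1 as opt-in and
    anything else (including unset) as the default.
    """
    env = os.environ if env is None else env
    use_pytorch_triton = env.get("NVTE_USE_PYTORCH_TRITON", "0") == "1"
    return "pytorch-triton" if use_pytorch_triton else "triton"


if __name__ == "__main__":
    print(triton_requirement())
```

Passing the environment as an argument keeps the function easy to test without mutating `os.environ`.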

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…for jax

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung requested a review from ksivaman December 23, 2025 01:18
num_ctas, # arg2: num_ctas (int)
compiled.metadata.shared, # arg3: shared_mem_bytes (int)
compiled.asm["ptx"], # arg4: ptx (str)
"", # arg5: ttir (str) - empty
Collaborator Author

This will soon be the same as main, since this change is also made in #1921, which is to be merged. It is included in this PR only so that I can test Triton calls locally with the nightly JAX container without running into errors caused by JAX 0.8.2+.

Contributor

greptile-apps bot commented Dec 23, 2025

Greptile Summary

This PR resolves package conflicts between pytorch-triton and triton by:

  • PyTorch extensions: Now use pytorch-triton by default (required for torch.compile())
  • JAX extensions: Use standard triton by default, with NVTE_USE_PYTORCH_TRITON=1 env var to opt into pytorch-triton for mixed environments
  • Runtime detection: Added _detect_triton_package() and _check_triton_compatibility() to detect which triton variant is installed and warn users appropriately
  • JAX 0.8.2+ compatibility: Added version check for gpu_triton.TritonKernel constructor API changes
  • New API: Added get_triton_info() function to query the installed triton package details
  • Documentation: Comprehensive docstrings explaining the package options and installation requirements
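The JAX 0.8.2+ check mentioned above amounts to a version gate. A minimal sketch of such a gate follows; the helper names are hypothetical, and what actually differs in the gpu_triton.TritonKernel constructor between versions is not shown in this summary, so the sketch only decides which code path to take.

```python
def parse_version(version):
    """Parse the leading 'major.minor.patch' of a version string into a tuple of ints.

    Suffixes such as '.dev20251223' or 'rc1' are ignored.
    """
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits or 0))
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def uses_new_triton_kernel_api(jax_version):
    """Return True when the installed JAX is 0.8.2 or newer (hypothetical helper)."""
    return parse_version(jax_version) >= (0, 8, 2)


if __name__ == "__main__":
    for v in ("0.8.1", "0.8.2", "0.8.2.dev20251223"):
        print(v, uses_new_triton_kernel_api(v))
```

Comparing integer tuples rather than raw strings avoids the usual pitfall where "0.10.0" sorts before "0.8.2".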

Confidence Score: 4/5

  • This PR is safe to merge with minor documentation cleanup needed
  • The changes are well-structured with comprehensive documentation and proper version detection logic. The only issue found is a placeholder text in a docstring. The core functionality for detecting triton packages and handling JAX API changes appears correct.
  • build_tools/pytorch.py has placeholder text in docstring that should be fixed

Important Files Changed

| Filename | Overview |
|---|---|
| build_tools/jax.py | Added NVTE_USE_PYTORCH_TRITON env var support to conditionally select triton or pytorch-triton package in test_requirements(). Clear documentation added. |
| build_tools/pytorch.py | Changed from 'triton' to 'pytorch-triton' dependency for PyTorch extensions. Added documentation warning about PyTorch index requirement. Contains placeholder text in docstring. |
| transformer_engine/jax/triton_extensions/__init__.py | Added comprehensive documentation for triton package options, env vars, and usage examples. Added get_triton_info() to documented API. |
| transformer_engine/jax/triton_extensions/utils.py | Added triton package detection (_detect_triton_package), compatibility check (_check_triton_compatibility), get_triton_info() API, and JAX 0.8.2+ version handling for TritonKernel API. |

Sequence Diagram

sequenceDiagram
    participant User
    participant setup.py
    participant build_tools/jax.py
    participant build_tools/pytorch.py
    participant triton_extensions/utils.py

    Note over User,triton_extensions/utils.py: Package Installation Flow
    
    User->>setup.py: pip install transformer-engine[pytorch]
    setup.py->>build_tools/pytorch.py: install_requirements()
    build_tools/pytorch.py-->>setup.py: ["pytorch-triton", ...]
    
    User->>setup.py: pip install transformer-engine[jax]
    setup.py->>build_tools/jax.py: test_requirements()
    build_tools/jax.py->>build_tools/jax.py: Check NVTE_USE_PYTORCH_TRITON
    alt NVTE_USE_PYTORCH_TRITON=1
        build_tools/jax.py-->>setup.py: ["pytorch-triton", ...]
    else Default
        build_tools/jax.py-->>setup.py: ["triton", ...]
    end
    
    Note over User,triton_extensions/utils.py: Runtime Detection Flow
    
    User->>triton_extensions/utils.py: import triton_extensions
    triton_extensions/utils.py->>triton_extensions/utils.py: _detect_triton_package()
    triton_extensions/utils.py->>triton_extensions/utils.py: _check_triton_compatibility()
    alt Placeholder package detected
        triton_extensions/utils.py-->>User: ImportError
    else pytorch-triton without env var
        triton_extensions/utils.py-->>User: UserWarning
    else Valid triton
        triton_extensions/utils.py-->>User: Success
    end
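The three runtime outcomes in the diagram can be sketched as a small decision function. This is only an illustration of the branching, assuming the detected package has already been classified as one of the strings "placeholder", "pytorch-triton", or "triton"; the function name, labels, and messages are hypothetical and are not the PR's _check_triton_compatibility().

```python
import warnings


def check_triton_compatibility(package, use_pytorch_triton):
    """Mirror the diagram: ImportError, UserWarning, or success.

    `package` is an assumed label for the detected distribution and
    `use_pytorch_triton` is whether the user opted in via the env var.
    """
    if package == "placeholder":
        # A placeholder package cannot run kernels, so fail at import time.
        raise ImportError("A placeholder triton package is installed.")
    if package == "pytorch-triton" and not use_pytorch_triton:
        # Usable, but the user did not opt in explicitly.
        warnings.warn(
            "pytorch-triton detected without NVTE_USE_PYTORCH_TRITON=1.",
            UserWarning,
        )
    return package


if __name__ == "__main__":
    print(check_triton_compatibility("triton", use_pytorch_triton=False))
```

Raising for the placeholder case but only warning for pytorch-triton matches the diagram's split between a hard failure and a usable but unconfirmed setup.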

Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/jax/triton_extensions/utils.py, line 322 (link)

    syntax: Typo: compile.name should be compiled.name. The variable compile is not defined in this scope - only compiled exists from line 300. This will cause a NameError at runtime for JAX versions < 0.8.2.

4 files reviewed, 1 comment

Signed-off-by: tdophung <tdophung@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. build_tools/pytorch.py, line 21 (link)

    style: Placeholder text <version??> should be replaced with an actual version (e.g., cu121 or cu124) or made generic.

4 files reviewed, 1 comment
