Skip to content

Conversation

@AoyuQC
Copy link

@AoyuQC AoyuQC commented Jan 27, 2026

Replace string-based compiler arguments with a type-safe CompilerConfig dataclass that provides discoverability and easy customization.

Changes:

  • Add CompilerConfig dataclass with fields for common neuronx-cc options (lnc, model_type, auto_cast, enable_mixed_precision_accumulation, etc.)
  • Add factory methods for_nkipy() and for_nki() with appropriate defaults
  • Add get_default_compiler_args() helper to inspect default settings
  • Add compiler_config parameter to @baremetal_jit, baremetal_run_traced_kernel, and DeviceKernel.compile_and_load()
  • Export CompilerConfig and get_default_compiler_args from nkipy.runtime
  • Add tutorial section demonstrating CompilerConfig usage

Backward compatible: legacy additional_compiler_args parameter still works.

Issue #, if available:
N/A

Description of changes:

Replace string-based compiler arguments with a type-safe CompilerConfig dataclass that provides discoverability and easy customization.

Changes:

  • Add CompilerConfig dataclass with fields for common neuronx-cc options (lnc, model_type, auto_cast, enable_mixed_precision_accumulation, etc.)
  • Add factory methods for_nkipy() and for_nki() with appropriate defaults
  • Add get_default_compiler_args() helper to inspect default settings
  • Add compiler_config parameter to @baremetal_jit, baremetal_run_traced_kernel(), and DeviceKernel.compile_and_load()
  • Export CompilerConfig and get_default_compiler_args from nkipy.runtime
  • Add tutorial section demonstrating CompilerConfig usage

Usage:

  from nkipy.runtime import CompilerConfig, baremetal_jit, get_default_compiler_args                                                                                                 
                                                                                                                                                                                     
  # View default compiler args                                                                                                                                                       
  print(get_default_compiler_args())  # "--lnc 1 --internal-tensorizer-opt-level=2"                                                                                                  
                                                                                                                                                                                     
  # Use with decorator                                                                                                                                                               
  @baremetal_jit(compiler_config=CompilerConfig.for_nkipy(model_type="transformer"))                                                                                                 
  def my_kernel(x): ...                                                                                                                                                              
                                                                                                                                                                                     
  # Use with DeviceKernel                                                                                                                                                            
  kernel = DeviceKernel.compile_and_load(                                                                                                                                            
      my_func, x, w,                                                                                                                                                                 
      compiler_config=CompilerConfig.for_nkipy(                                                                                                                                      
          lnc=2,                                                                                                                                                                     
          enable_mixed_precision_accumulation=True,                                                                                                                                  
      )                                                                                                                                                                              
  )   

Backward compatible: legacy additional_compiler_args parameter still works.

Runned Tests:

  • Existing unit tests pass (uv run pytest tests/unit/)
  • Linting passes (uv run ruff check .)
  • Manual verification of CompilerConfig functionality
  • Backward compatibility verified with legacy additional_compiler_args

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

@AoyuQC AoyuQC requested a review from a team January 27, 2026 02:48

# Internal/advanced options
if self.tensorizer_opt_level:
args.append(f"--internal-tensorizer-opt-level={self.tensorizer_opt_level}")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, already removed options not documented in public api, they can still be passed via extra_args or through pure string mode

    Replace string-based compiler arguments with a type-safe CompilerConfig
    dataclass that provides discoverability and easy customization.

    Changes:
    - Add CompilerConfig dataclass with fields for common neuronx-cc options
      (lnc, model_type, auto_cast, enable_mixed_precision_accumulation, etc.)
    - Add factory methods for_nkipy() and for_nki() with appropriate defaults
    - Add get_default_compiler_args() helper to inspect default settings
    - Add compiler_config parameter to @baremetal_jit, baremetal_run_traced_kernel,
      and DeviceKernel.compile_and_load()
    - Export CompilerConfig and get_default_compiler_args from nkipy.runtime
    - Add tutorial section demonstrating CompilerConfig usage

    Backward compatible: legacy additional_compiler_args parameter still works.
@AoyuQC AoyuQC force-pushed the feat/compiler-config branch from 07f7331 to a8eb03c Compare January 28, 2026 06:19
@AoyuQC AoyuQC requested a review from liangfu January 28, 2026 06:22
Copy link

@liangfu liangfu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants