diff --git a/nkipy/src/nkipy/__init__.py b/nkipy/src/nkipy/__init__.py index 04f8b7b..42b8b4d 100644 --- a/nkipy/src/nkipy/__init__.py +++ b/nkipy/src/nkipy/__init__.py @@ -1,2 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 + +from nkipy.core.knob import knob + +__all__ = ["knob"] diff --git a/nkipy/src/nkipy/core/backend/__init__.py b/nkipy/src/nkipy/core/backend/__init__.py index 3c5e400..cebb4a1 100644 --- a/nkipy/src/nkipy/core/backend/__init__.py +++ b/nkipy/src/nkipy/core/backend/__init__.py @@ -12,9 +12,124 @@ from __future__ import annotations from contextlib import contextmanager -from typing import Optional, Tuple +from dataclasses import dataclass +from typing import Dict, List, Optional, Protocol, Tuple, runtime_checkable +import numpy as np + +# --------------------------------------------------------------------------- +# Shared IR data types +# --------------------------------------------------------------------------- + + +@dataclass +class TensorPlaceholder: + """Lightweight tensor metadata used by the execution pipeline. + + Attributes: + name: Identifier used to key this tensor in input/output dicts at runtime. + shape: Static shape of the tensor. + dtype: NumPy dtype of the tensor elements. + original_name: User-facing parameter name. Defaults to *name* when not set. + """ + + name: str + shape: Tuple[int, ...] + dtype: np.dtype + original_name: Optional[str] = None + + def __post_init__(self): + if self.original_name is None: + self.original_name = self.name + + +@dataclass(frozen=True) +class AliasInfo: + """One input-output alias pair. + + Attributes: + output_index: Position of this alias in the IR outputs list. + param_index: Position of the aliased parameter in the IR inputs list. + param_name: Name of the aliased input parameter. + is_user_returned: True when the user's kernel explicitly returns this + tensor. False when the framework auto-appended it as an output + solely to write back an in-place mutation. + """ + + output_index: int + param_index: int + param_name: str + is_user_returned: bool + + +# --------------------------------------------------------------------------- +# IR Protocol — the interface that every backend IR must satisfy +# --------------------------------------------------------------------------- + + +@runtime_checkable +class ComputationIR(Protocol): + """Protocol satisfied by both ``HLOModule`` and ``NkiGenIR``.""" + + @property + def inputs(self) -> List[TensorPlaceholder]: ... + + @property + def outputs(self) -> List[TensorPlaceholder]: ... + + @property + def aliases(self) -> List[AliasInfo]: + """Input-output alias pairs for in-place mutations.""" + ... + + @property + def auto_aliased_indices(self) -> set[int]: + """Output indices implicitly appended for write-back, not user-returned.""" + ... + + def content_hash(self, compiler_args: str) -> str: + """Deterministic hash of IR content and compiler flags for caching.""" + ... + + +def prepare_io_mapping( + inputs: List[TensorPlaceholder], + aliases: List[AliasInfo], + original_inputs: Dict[str, np.ndarray], +) -> Tuple[Dict[str, np.ndarray], Dict[int, str]]: + """Map parameter names to backend-specific input names and resolve aliases. + + Args: + inputs: IR input placeholders (from ``ir.inputs``). + aliases: IR alias pairs (from ``ir.aliases``). + original_inputs: User-provided arrays keyed by parameter name. + + Returns: + A tuple of (input_arrays, alias_input_names) where: + - input_arrays maps backend IR input names to numpy arrays. + - alias_input_names maps output index to the IR input name that the + aliased output should share a buffer with. + """ + if len(original_inputs) != len(inputs): + raise RuntimeError( + f"Expected {len(inputs)} tensor arguments, " + f"got {len(original_inputs)}" + ) + input_arrays = { + inp.name: original_inputs[inp.original_name] + for inp in inputs + } + alias_input_names = { + alias.output_index: inputs[alias.param_index].name + for alias in aliases + } + return input_arrays, alias_input_names + + +# --------------------------------------------------------------------------- # Package-private active context — shared with submodules (e.g. hlo.py). +# --------------------------------------------------------------------------- + _active_ctx = None diff --git a/nkipy/src/nkipy/core/backend/hlo.py b/nkipy/src/nkipy/core/backend/hlo.py index a76b1bc..0789930 100644 --- a/nkipy/src/nkipy/core/backend/hlo.py +++ b/nkipy/src/nkipy/core/backend/hlo.py @@ -8,6 +8,7 @@ from __future__ import annotations +import hashlib import struct from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple, Union @@ -15,6 +16,7 @@ import ml_dtypes import numpy as np +from nkipy.core.backend import AliasInfo, TensorPlaceholder from nkipy.third_party.xla import xla_data_pb2 from nkipy.third_party.xla.service import hlo_pb2 @@ -311,27 +313,6 @@ class HLOTensor: id: Optional[int] = None -@dataclass -class TensorPlaceholder: - """Placeholder for tensor metadata.""" - - name: str - shape: Tuple[int, ...] - dtype: np.dtype - - -@dataclass(frozen=True) -class AliasInfo: - """One input-output alias pair.""" - - output_index: int # Position in HLO output tuple - param_index: int # Position in HLO parameter list - param_name: str # Original parameter name (e.g., "a") - is_user_returned: ( - bool # False = auto-added output, True = user explicitly returned it - ) - - # ============================================================================= # HLO Module # ============================================================================= @@ -362,7 +343,13 @@ def auto_aliased_indices(self) -> set[int]: def inputs(self) -> List[TensorPlaceholder]: """Return parameters as inputs for compatibility with IR Function interface.""" return [ - TensorPlaceholder(name=p.name, shape=p.shape, dtype=p.dtype) + TensorPlaceholder( + name=p.name, + shape=p.shape, + dtype=p.dtype, + # Neuron compiler appends ".must_alias_input" to mutated params + original_name=p.name.split(".must_alias_input")[0], + ) for p in self.parameters ] @@ -407,6 +394,12 @@ def set_results(self, results: Union[HLOTensor, List[HLOTensor]]) -> None: """Set the output results of the module.""" self.results = results if isinstance(results, list) else [results] + def content_hash(self, compiler_args: str) -> str: + h = hashlib.sha256() + h.update(self.to_proto().SerializeToString()) + h.update(compiler_args.encode("utf-8")) + return h.hexdigest()[:12] + # ========================================================================= # Proto Generation # ========================================================================= diff --git a/nkipy/src/nkipy/core/backend/nkigen.py b/nkipy/src/nkipy/core/backend/nkigen.py new file mode 100644 index 0000000..97e6eaa --- /dev/null +++ b/nkipy/src/nkipy/core/backend/nkigen.py @@ -0,0 +1,180 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""NkiGen backend for NKIPy. + +This module provides the nkigen backend by delegating to +``nkigen.builder`` for all MLIR construction. No MLIR types +are imported or exposed — the builder API is the sole interface. +""" + +from __future__ import annotations + +import hashlib +from typing import List + +import numpy as np + +from nkipy.core.backend import AliasInfo, TensorPlaceholder + + +# --------------------------------------------------------------------------- +# NkiGenTensor -- analogue of HLOTensor +# --------------------------------------------------------------------------- + +class NkiGenTensor: + """Backend tensor for the nkigen backend. + + Wraps an opaque ``TensorHandle`` from ``nkigen.builder`` + with the metadata that ``NKIPyTensorRef`` expects. + """ + + __slots__ = ("handle", "shape", "dtype", "is_parameter", "parameter_id", "name", "id") + + _next_id = 0 + + def __init__(self, handle, shape, dtype, *, is_parameter=False, parameter_id=None, name=""): + self.handle = handle + self.shape = tuple(shape) + self.dtype = np.dtype(dtype) if not isinstance(dtype, np.dtype) else dtype + self.is_parameter = is_parameter + self.parameter_id = parameter_id + self.name = name + self.id = NkiGenTensor._next_id + NkiGenTensor._next_id += 1 + + +# --------------------------------------------------------------------------- +# NkiGenTraceContext +# --------------------------------------------------------------------------- + +class NkiGenTraceContext: + """Trace context that delegates to ``nkigen.builder.IRBuilder``.""" + + backend_name = "nkigen" + + def __init__(self): + from nkigen.builder import IRBuilder + self._builder = IRBuilder() + self._parameters: List[NkiGenTensor] = [] + self.current_source_location = None + + @property + def module(self): + """Return the underlying MLIR module from the builder.""" + return self._builder.module + + def set_source_location(self, location): + """Set the current source location for diagnostic tracking.""" + self.current_source_location = location + + def _begin_function(self, name, arg_shapes, arg_dtypes): + """Start an MLIR function and return parameter tensors.""" + handles = self._builder.begin_function(name, arg_shapes, arg_dtypes) + tensors = [] + for i, (h, (shape, dtype)) in enumerate( + zip(handles, zip(arg_shapes, arg_dtypes)) + ): + kt = NkiGenTensor( + h, shape, dtype, + is_parameter=True, parameter_id=i, name=f"arg{i}" + ) + self._parameters.append(kt) + tensors.append(kt) + return tensors + + def _finish_function(self, result_tensors): + """Finalize the MLIR function with the given result tensors.""" + self._builder.finish_function([t.handle for t in result_tensors]) + + def _run_canonicalize(self): + """Run MLIR canonicalization passes on the module.""" + self._builder.run_canonicalize() + + def _get_ir_text(self): + """Export the MLIR module as a text string.""" + return self._builder.get_ir_text() + + def _cleanup(self): + """Release builder resources.""" + self._builder.cleanup() + + +# --------------------------------------------------------------------------- +# Module-level context accessor +# --------------------------------------------------------------------------- + +def get_nkigen_context() -> NkiGenTraceContext: + """Return the active ``NkiGenTraceContext``, or raise if none is active.""" + from nkipy.core.backend import _active_ctx + if _active_ctx is None or _active_ctx.backend_name != "nkigen": + raise RuntimeError("No active nkigen trace context") + return _active_ctx + + +# --------------------------------------------------------------------------- +# NkiGenIR -- make MLIR IR compatible with execution pipeline +# --------------------------------------------------------------------------- + + +class NkiGenIR: + """Adapter that makes an MLIR module compatible with the execution pipeline. + + Provides the same interface as ``HLOModule`` (``.inputs``, ``.outputs``, + ``.aliases``, ``.auto_aliased_indices``) so that ``compile.py`` and + ``execute.py`` can handle both backends uniformly. + """ + + def __init__(self, mlir_text, func_name, input_specs, output_specs, + alias_map=None, user_return_len=None, original_param_names=None): + self._mlir_text = mlir_text + self._func_name = func_name + self._input_specs = input_specs # [(name, shape, dtype), ...] + self._output_specs = output_specs # [(name, shape, dtype), ...] + # alias_map: {output_index: (param_name, param_index)} + self._alias_map = alias_map or {} + self._user_return_len = user_return_len if user_return_len is not None else len(output_specs) + # Positionally aligned with _input_specs: original_param_names[i] is the + # user-facing parameter name for the i-th NEFF input ("in_tensor_i"). + self._original_param_names = original_param_names or [] + + @property + def inputs(self): + """Return input tensor metadata as ``TensorPlaceholder`` list.""" + return [ + TensorPlaceholder(n, tuple(s), np.dtype(d), original_name=self._original_param_names[i]) + for i, (n, s, d) in enumerate(self._input_specs) + ] + + @property + def outputs(self): + """Return output tensor metadata as ``TensorPlaceholder`` list.""" + return [TensorPlaceholder(n, tuple(s), np.dtype(d)) for n, s, d in self._output_specs] + + @property + def aliases(self): + """Return input-output alias pairs as ``AliasInfo`` list.""" + return [ + AliasInfo( + output_index=out_idx, + param_index=pidx, + param_name=pname, + is_user_returned=out_idx < self._user_return_len, + ) + for out_idx, (pname, pidx) in self._alias_map.items() + ] + + @property + def auto_aliased_indices(self): + """Output indices that were auto-added (not user-returned).""" + return { + out_idx for out_idx in self._alias_map + if out_idx >= self._user_return_len + } + + def content_hash(self, compiler_args: str) -> str: + """Compute a content hash from the MLIR text and compiler args.""" + h = hashlib.sha256() + h.update(self._mlir_text.encode("utf-8")) + h.update(compiler_args.encode("utf-8")) + return h.hexdigest()[:12] + diff --git a/nkipy/src/nkipy/core/compile.py b/nkipy/src/nkipy/core/compile.py index b8a024a..15b144e 100644 --- a/nkipy/src/nkipy/core/compile.py +++ b/nkipy/src/nkipy/core/compile.py @@ -112,89 +112,67 @@ class Compiler: def __init__(self, config: CompilationConfig): self.config = config - def _build_compile_command(self, mode="hlo") -> List[str]: + def _resolve_target(self) -> CompilationTarget: + if self.config.target != CompilationTarget.DEFAULT: + return self.config.target + try: + return get_platform_target() + except Exception: + logging.warning( + "Failed to detect platform target, falling back to trn2..." + ) + return CompilationTarget.TRN2 + + def _build_hlo_compile_command(self, work_dir: Path) -> List[str]: + """Build the neuronx-cc command line for HLO compilation.""" + target = self._resolve_target() + self.config.target = target + cmd = [ - "neuronx-cc", - "compile", - "--framework", - "XLA", + "neuronx-cc", "compile", + "--framework", "XLA", + str(work_dir / "hlo_module.pb"), + "--pipeline", *self.config.pipeline, + "--target", target.value, + f"--output={self.config.neff_name}", ] - if mode == "hlo": - cmd.extend(["hlo_module.pb"]) - else: - raise RuntimeError(f"Unknown mode: {mode}") - - cmd.append("--pipeline") - cmd.extend(self.config.pipeline) - - # When using default target, detect platform target - if self.config.target == CompilationTarget.DEFAULT: - try: - self.config.target = get_platform_target() - except Exception: - logging.warning( - "Failed to detect platform target, falling back to trn1..." - ) - self.config.target = CompilationTarget.TRN1 - - cmd.extend( - ["--target", self.config.target.value, f"--output={self.config.neff_name}"] - ) if self.config.additional_args: cmd.extend(shlex.split(self.config.additional_args)) return cmd - def compile( + @staticmethod + def _compilation_error(message, cmd=None, result=None): + """Build a RuntimeError with compiler output when available.""" + parts = [message] + if cmd is not None: + parts.append(f"Command: {' '.join(cmd)}") + if result is not None: + def decode(b): + return b.decode("utf-8", errors="replace") if b else "" + parts.append(f"stderr:\n{decode(result.stderr)}") + parts.append(f"stdout:\n{decode(result.stdout)}") + return RuntimeError("\n".join(parts)) + + def _compile_hlo( self, ir, work_dir: Path, output_file: str, use_neuronx_cc_python_interface: bool = False, ) -> Path: - """ - Run compilation in specified directory - - Args: - ir: The IR to compile - work_dir: Directory to compile in - output_file: Name of the output file to check for ("file.neff" or "nki.py") + """Compile an HLOModule to NEFF via neuronx-cc.""" + hlo_pb_path = work_dir / "hlo_module.pb" + proto = ir.to_proto() + with open(hlo_pb_path, "wb") as f: + f.write(proto.SerializeToString()) - Returns: - Path to the output file - """ - - mode = "hlo" if isinstance(ir, HLOModule) else "unknown" - cmd = self._build_compile_command(mode) - - def _compilation_error(message, result=None): - """Build a RuntimeError with compiler output when available.""" - parts = [message, f"Command: {' '.join(cmd)}"] - if result is not None: - - def decode(b): - return b.decode("utf-8", errors="replace") if b else "" - - parts.append(f"stderr:\n{decode(result.stderr)}") - parts.append(f"stdout:\n{decode(result.stdout)}") - return RuntimeError("\n".join(parts)) + cmd = self._build_hlo_compile_command(work_dir) current_dir = os.getcwd() try: os.chdir(work_dir) - if mode == "hlo": - hlo_pb_path = "hlo_module.pb" - proto = ir.to_proto() - with open(hlo_pb_path, "wb") as f: - f.write(proto.SerializeToString()) - else: - raise RuntimeError( - f"Unknown mode: {mode}. " - "Note: For NKI kernels, You can either embed a NKI kernel as an op" - " in NKIPy kernel or implement your own helper function to get the" - " NEFF from a NKI kernel." - ) if use_neuronx_cc_python_interface: original_argv = sys.argv.copy() sys.argv = cmd @@ -202,9 +180,9 @@ def decode(b): else: result = subprocess.run(cmd, capture_output=True) if result.returncode != 0: - raise _compilation_error( + raise self._compilation_error( f"Compilation failed (exit code {result.returncode}).", - result, + cmd, result, ) finally: if use_neuronx_cc_python_interface: @@ -213,13 +191,66 @@ def decode(b): output_path = work_dir / output_file if not output_path.exists(): - raise _compilation_error( + raise self._compilation_error( f"Compilation failed: {output_file} expected but not generated.", + cmd, result if not use_neuronx_cc_python_interface else None, ) + return output_path + + def _compile_nkigen(self, ir, work_dir: Path, output_file: str) -> Path: + """Compile a NkiGenIR module to NEFF via nkigen.""" + from nkigen.compile import compile_to_neff + + target_str = self._resolve_target().value + cc_args = tuple(shlex.split(self.config.additional_args)) if self.config.additional_args else () + + compile_to_neff( + ir._mlir_text, + ir._func_name, + input_specs=[(s.name, s.shape, s.dtype) for s in ir.inputs], + output_specs=[(s.name, s.shape, s.dtype) for s in ir.outputs], + target=target_str, + output_path=str(work_dir / output_file), + artifacts_dir=str(work_dir), + neuronx_cc_args=cc_args, + ) + + output_path = work_dir / output_file + if not output_path.exists(): + raise self._compilation_error( + f"NkiGen compilation failed: {output_file} not generated." + ) return output_path + def compile( + self, + ir, + work_dir: Path, + output_file: str, + use_neuronx_cc_python_interface: bool = False, + ) -> Path: + """Compile an IR module to a NEFF file. + + Dispatches to ``_compile_hlo`` or ``_compile_nkigen`` based on + the IR type. + """ + if isinstance(ir, HLOModule): + return self._compile_hlo( + ir, work_dir, output_file, use_neuronx_cc_python_interface + ) + + from nkipy.core.backend.nkigen import NkiGenIR + + if isinstance(ir, NkiGenIR): + return self._compile_nkigen(ir, work_dir, output_file) + + raise RuntimeError( + f"Unknown IR type: {type(ir).__name__}. " + "Expected HLOModule or NkiGenIR." + ) + def compile_in_directory( self, ir, diff --git a/nkipy/src/nkipy/core/knob.py b/nkipy/src/nkipy/core/knob.py new file mode 100644 index 0000000..5764f3d --- /dev/null +++ b/nkipy/src/nkipy/core/knob.py @@ -0,0 +1,70 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Public knob() API for annotating tensors with hardware placement and tiling hints. + +Dispatches based on the active tracing backend: +- nkigen: emits nkipy.AnnotateOp into MLIR +- hlo: warns and ignores +- cpu / no trace: no-op pass-through +""" + +from __future__ import annotations + +import warnings +from typing import List, Optional + + +def knob( + tensor, + *, + partition_dim: Optional[int] = None, + mem_space: Optional[str] = None, + tile_size: Optional[List[int]] = None, + reduction_tile: Optional[List[int]] = None, +): + """Annotate a tensor with hardware placement and tiling hints. + + Only effective when using the nkigen backend. When used with the HLO + backend, issues a warning and returns the tensor unchanged. + + Args: + tensor: The tensor to annotate. + partition_dim: Dimension to partition (must be < tensor rank). + mem_space: Memory space ("Hbm", "Psum", "Sbuf", or "SharedHbm"). + tile_size: Tile sizes for each dimension. + reduction_tile: Tile sizes for reduction dimensions (e.g., K in matmul). + + Returns: + The same tensor, unchanged. + """ + from nkipy.core.backend import get_backend + + backend = get_backend() + + if backend == "nkigen": + from nkipy.core.tensor import NKIPyTensorRef + + if not isinstance(tensor, NKIPyTensorRef): + return tensor + + if mem_space is None and partition_dim is None and tile_size is None and reduction_tile is None: + return tensor + + import nkigen.builder as B + + B.annotate( + tensor.backend_tensor.handle, + partition_dim=partition_dim, + mem_space=mem_space, + tile_size=tile_size, + reduction_tile=reduction_tile, + ) + return tensor + elif backend == "hlo": + warnings.warn( + "knob() annotations are only effective with backend='nkigen'. " + "Ignoring annotation.", + stacklevel=2, + ) + + return tensor diff --git a/nkipy/src/nkipy/core/nki_op.py b/nkipy/src/nkipy/core/nki_op.py index 3ac472b..c042d87 100644 --- a/nkipy/src/nkipy/core/nki_op.py +++ b/nkipy/src/nkipy/core/nki_op.py @@ -1,8 +1,8 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -"""NKI kernel integration for NKIPy - wraps NKI kernels as HLO custom-calls +"""NKI kernel integration for NKIPy. -This module provides two ways to use NKI kernels in NKIPy: +This module provides three ways to use NKI kernels in NKIPy: 1. Direct @nki.jit support (lazy/dynamic): - Any kernel decorated with @nki.jit can be called directly during NKIPy tracing @@ -14,6 +14,10 @@ - Returns a NKICustomOp that only works with those shapes - Useful for explicit control over specialization +3. nki_custom_op for cross-backend custom ops: + - Accepts both @nki.jit (HLO backend) and kernel_builder (nkigen backend) + - Dispatches to the correct implementation based on the active backend + Supports two NKI frontends: - Legacy frontend (neuronxcc.nki): Default, supports CPU execution - Beta 2 frontend (nki): New frontend, hardware-only (no CPU execution support) @@ -21,7 +25,7 @@ import dataclasses import inspect -from typing import Callable, Iterable, Optional, Tuple +from typing import Callable, Iterable, List, Optional, Tuple import numpy as np @@ -274,10 +278,10 @@ def _patched_beta2_generic_kernel_call(self, *args, **kwargs): class NKICustomOp: - """Backward-compatible NKI custom op class. + """HLO custom-call wrapper for a pre-traced NKI kernel. - This class provides the original API for wrapping NKI kernels. - New code should use wrap_nki_kernel() or direct @nki.jit instead. + Pre-traces the kernel at construction time for specific operand shapes. + Used by ``wrap_nki_kernel``. """ def __init__( @@ -362,3 +366,111 @@ def wrap_nki_kernel( is_nki_beta_2_version=is_nki_beta_2_version, platform_target=platform_target, ) + + +# --------------------------------------------------------------------------- +# NkiGen custom op support +# --------------------------------------------------------------------------- + + +def _generate_nkigen_custom_call(kernel_builder, input_specs, output_specs, *args): + """Compile a kernel_builder function and inline it during nkigen tracing.""" + from nkigen.builder import apply_custom_op + + return apply_custom_op( + kernel_builder=kernel_builder, + reference_fn=None, + input_specs=input_specs, + output_specs=output_specs, + args=args, + ) + + +# --------------------------------------------------------------------------- +# Unified custom op interface +# --------------------------------------------------------------------------- + + +def nki_custom_op( + *, + nki_kernel: Optional[Callable] = None, + kernel_builder: Optional[Callable] = None, + input_specs: Optional[List[Tuple[Tuple[int, ...], str]]] = None, + output_specs: Optional[List[Tuple[Tuple[int, ...], str]]] = None, +) -> "NKICustomOpHandle": + """Create a cross-backend custom NKI op. + + Args: + nki_kernel: ``@nki.jit`` decorated kernel for the HLO backend. + kernel_builder: ``nki.compiler.kernel_builder`` function for the + nkigen backend. Requires ``input_specs`` and ``output_specs``. + input_specs: List of ``((shape), dtype_str)`` for each input. + Required when ``kernel_builder`` is provided. + output_specs: List of ``((shape), dtype_str)`` for each output. + Required when ``kernel_builder`` is provided. + + Returns: + An ``NKICustomOpHandle`` callable that dispatches to the correct + backend at call time. + """ + if nki_kernel is None and kernel_builder is None: + raise ValueError( + "At least one of nki_kernel or kernel_builder must be provided." + ) + if kernel_builder is not None: + if input_specs is None or output_specs is None: + raise ValueError( + "input_specs and output_specs are required when kernel_builder " + "is provided." + ) + return NKICustomOpHandle( + nki_kernel=nki_kernel, + kernel_builder=kernel_builder, + input_specs=input_specs, + output_specs=output_specs, + ) + + +class NKICustomOpHandle: + """Backend-aware callable wrapping a custom NKI op definition.""" + + def __init__( + self, + *, + nki_kernel: Optional[Callable], + kernel_builder: Optional[Callable], + input_specs: Optional[List[Tuple[Tuple[int, ...], str]]], + output_specs: Optional[List[Tuple[Tuple[int, ...], str]]], + ): + self._nki_kernel = nki_kernel + self._kernel_builder = kernel_builder + self._input_specs = input_specs + self._output_specs = output_specs + + def __call__(self, *args): + backend = get_backend() + + if backend == "hlo": + if self._nki_kernel is None: + raise RuntimeError( + "nki_custom_op has no nki_kernel for the HLO backend. " + "Provide an @nki.jit decorated kernel via nki_kernel=." + ) + return _generate_nki_custom_call(self._nki_kernel, *args) + + if backend == "nkigen": + if self._kernel_builder is None: + raise RuntimeError( + "nki_custom_op has no kernel_builder for the nkigen " + "backend. Provide a kernel_builder function via " + "kernel_builder=." + ) + return _generate_nkigen_custom_call( + self._kernel_builder, self._input_specs, self._output_specs, + *args, + ) + + raise RuntimeError( + f"nki_custom_op is not supported on backend '{backend}'. " + f"Use the 'hlo' or 'nkigen' backend." + ) diff --git a/nkipy/src/nkipy/core/ops/_hlo_impls.py b/nkipy/src/nkipy/core/ops/_hlo_impls.py new file mode 100644 index 0000000..eddcb59 --- /dev/null +++ b/nkipy/src/nkipy/core/ops/_hlo_impls.py @@ -0,0 +1,2817 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""HLO backend implementations for NKIPy ops. + +Contains all primitive HLO lowerings. Composed ops (floor_divide, tan, rint, +etc.) are registered as ``composed_impl`` on the Op itself and need no +per-backend registration. +""" + +from __future__ import annotations + +import builtins +import itertools +from typing import List, Tuple + +import numpy as np + +from nkipy.core.backend.hlo import ( + HLOOp, + as_hlo_tensor, + broadcast_operands_hlo, + broadcast_to_shape_hlo, + find_common_type_hlo, + get_hlo_context, +) +from nkipy.core.tensor import NKIPyTensorRef + +builtins_min = builtins.min + +# ============================================================================= +# Binary ops +# ============================================================================= + + +def _build_binary_hlo(x, y, np_op, out=None, dtype=None): + ctx = get_hlo_context() + + promoted_dtype = find_common_type_hlo(x, y) + + if dtype is not None: + output_dtype = np.dtype(dtype) + else: + output_dtype = promoted_dtype + + x = ( + x.backend_tensor + if isinstance(x, NKIPyTensorRef) + else as_hlo_tensor(ctx, x, promoted_dtype) + ) + y = ( + y.backend_tensor + if isinstance(y, NKIPyTensorRef) + else as_hlo_tensor(ctx, y, promoted_dtype) + ) + + op_map = { + np.add: "add", + np.subtract: "subtract", + np.multiply: "multiply", + np.divide: "divide", + np.power: "power", + np.maximum: "maximum", + np.minimum: "minimum", + np.bitwise_and: "and", + np.bitwise_or: "or", + np.bitwise_xor: "xor", + np.logical_and: "and", + np.logical_or: "or", + np.logical_xor: "xor", + } + + hlo_op = op_map.get( + np_op, np_op.__name__ if hasattr(np_op, "__name__") else str(np_op) + ) + + x_broadcast, y_broadcast = broadcast_operands_hlo(ctx, x, y) + + if output_dtype != x_broadcast.dtype: + x_broadcast = ctx.build_op( + "convert", [x_broadcast], x_broadcast.shape, output_dtype + ) + if output_dtype != y_broadcast.dtype: + y_broadcast = ctx.build_op( + "convert", [y_broadcast], y_broadcast.shape, output_dtype + ) + + result_tensor = ctx.build_op( + hlo_op, [x_broadcast, y_broadcast], x_broadcast.shape, output_dtype + ) + + return NKIPyTensorRef(result_tensor) + + +def _build_comparison_hlo(x, y, np_op, out=None, dtype=None): + ctx = get_hlo_context() + + promoted_dtype = find_common_type_hlo(x, y) + + x = ( + x.backend_tensor + if isinstance(x, NKIPyTensorRef) + else as_hlo_tensor(ctx, x, promoted_dtype) + ) + y = ( + y.backend_tensor + if isinstance(y, NKIPyTensorRef) + else as_hlo_tensor(ctx, y, promoted_dtype) + ) + + if x.dtype != promoted_dtype: + x = ctx.build_op("convert", [x], x.shape, promoted_dtype) + if y.dtype != promoted_dtype: + y = ctx.build_op("convert", [y], y.shape, promoted_dtype) + + comp_map = { + np.equal: "EQ", + np.not_equal: "NE", + np.less: "LT", + np.less_equal: "LE", + np.greater: "GT", + np.greater_equal: "GE", + } + + comp_dir = comp_map.get(np_op, "EQ") + + x_broadcast, y_broadcast = broadcast_operands_hlo(ctx, x, y) + + result_tensor = ctx.build_op( + "compare", + [x_broadcast, y_broadcast], + x_broadcast.shape, + np.bool_, + {"comparison_direction": comp_dir}, + ) + + return NKIPyTensorRef(result_tensor) + + +def _build_logical_hlo(x, y, hlo_op_name, out=None, dtype=None): + ctx = get_hlo_context() + + promoted_dtype = find_common_type_hlo(x, y) + + x = ( + x.backend_tensor + if isinstance(x, NKIPyTensorRef) + else as_hlo_tensor(ctx, x, promoted_dtype) + ) + y = ( + y.backend_tensor + if isinstance(y, NKIPyTensorRef) + else as_hlo_tensor(ctx, y, promoted_dtype) + ) + + x_broadcast, y_broadcast = broadcast_operands_hlo(ctx, x, y) + + zero_x = as_hlo_tensor(ctx, 0, x_broadcast.dtype) + if x_broadcast.shape: + zero_x = ctx.build_op( + "broadcast", + [zero_x], + x_broadcast.shape, + x_broadcast.dtype, + {"broadcast_dimensions": []}, + ) + x_bool = ctx.build_op( + "compare", + [x_broadcast, zero_x], + x_broadcast.shape, + np.bool_, + {"comparison_direction": "NE"}, + ) + + zero_y = as_hlo_tensor(ctx, 0, y_broadcast.dtype) + if y_broadcast.shape: + zero_y = ctx.build_op( + "broadcast", + [zero_y], + y_broadcast.shape, + y_broadcast.dtype, + {"broadcast_dimensions": []}, + ) + y_bool = ctx.build_op( + "compare", + [y_broadcast, zero_y], + y_broadcast.shape, + np.bool_, + {"comparison_direction": "NE"}, + ) + + result_tensor = ctx.build_op( + hlo_op_name, [x_bool, y_bool], x_broadcast.shape, np.bool_ + ) + + return NKIPyTensorRef(result_tensor) + + +def add(x, y, out=None, dtype=None): + return _build_binary_hlo(x, y, np.add, out=out, dtype=dtype) + + +def subtract(x, y, out=None, dtype=None): + return _build_binary_hlo(x, y, np.subtract, out=out, dtype=dtype) + + +def multiply(x, y, out=None, dtype=None): + return _build_binary_hlo(x, y, np.multiply, out=out, dtype=dtype) + + +def divide(x, y, out=None, dtype=None): + return _build_binary_hlo(x, y, np.divide, out=out, dtype=dtype) + + +def power(x, y, out=None, dtype=None): + return _build_binary_hlo(x, y, np.power, out=out, dtype=dtype) + + +def maximum(x, y, out=None, dtype=None): + return _build_binary_hlo(x, y, np.maximum, out=out, dtype=dtype) + + +def minimum(x, y, out=None, dtype=None): + return _build_binary_hlo(x, y, np.minimum, out=out, dtype=dtype) + + +def bitwise_and(x, y, out=None, dtype=None): + return _build_binary_hlo(x, y, np.bitwise_and, out=out, dtype=dtype) + + +def bitwise_or(x, y, out=None, dtype=None): + return _build_binary_hlo(x, y, np.bitwise_or, out=out, dtype=dtype) + + +def bitwise_xor(x, y, out=None, dtype=None): + return _build_binary_hlo(x, y, np.bitwise_xor, out=out, dtype=dtype) + + +def equal(x, y, out=None, dtype=None): + return _build_comparison_hlo(x, y, np.equal, out=out, dtype=dtype) + + +def not_equal(x, y, out=None, dtype=None): + return _build_comparison_hlo(x, y, np.not_equal, out=out, dtype=dtype) + + +def greater(x, y, out=None, dtype=None): + return _build_comparison_hlo(x, y, np.greater, out=out, dtype=dtype) + + +def greater_equal(x, y, out=None, dtype=None): + return _build_comparison_hlo(x, y, np.greater_equal, out=out, dtype=dtype) + + +def less(x, y, out=None, dtype=None): + return _build_comparison_hlo(x, y, np.less, out=out, dtype=dtype) + + +def less_equal(x, y, out=None, dtype=None): + return _build_comparison_hlo(x, y, np.less_equal, out=out, dtype=dtype) + + +def logical_and(x, y, out=None, dtype=None): + return _build_logical_hlo(x, y, "and", out=out, dtype=dtype) + + +def logical_or(x, y, out=None, dtype=None): + return _build_logical_hlo(x, y, "or", out=out, dtype=dtype) + + +def logical_xor(x, y, out=None, dtype=None): + return _build_logical_hlo(x, y, "xor", out=out, dtype=dtype) + + +# ============================================================================= +# Unary ops +# ============================================================================= + + +def _build_unary_hlo(x, np_op, out=None, dtype=None): + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x = x.backend_tensor + + if np_op == np.arctan: + one_tensor = as_hlo_tensor(ctx, 1.0, x.dtype) + if x.shape: + one_tensor = ctx.build_op( + "broadcast", + [one_tensor], + x.shape, + x.dtype, + {"broadcast_dimensions": []}, + ) + result_tensor = ctx.build_op("atan2", [x, one_tensor], x.shape, x.dtype) + return NKIPyTensorRef(result_tensor) + + op_map = { + np.abs: "abs", + np.exp: "exponential", + np.log: "log", + np.sqrt: "sqrt", + np.sin: "sine", + np.cos: "cosine", + np.tanh: "tanh", + np.negative: "negate", + np.ceil: "ceil", + np.floor: "floor", + np.sign: "sign", + np.bitwise_not: "not", + np.invert: "not", + } + + hlo_op = op_map.get( + np_op, np_op.__name__ if hasattr(np_op, "__name__") else str(np_op) + ) + result_tensor = ctx.build_op(hlo_op, [x], x.shape, x.dtype) + + return NKIPyTensorRef(result_tensor) + + +def abs(x, out=None, dtype=None): + return _build_unary_hlo(x, np.abs, out=out, dtype=dtype) + + +def exp(x, out=None, dtype=None): + return _build_unary_hlo(x, np.exp, out=out, dtype=dtype) + + +def log(x, out=None, dtype=None): + return _build_unary_hlo(x, np.log, out=out, dtype=dtype) + + +def sqrt(x, out=None, dtype=None): + return _build_unary_hlo(x, np.sqrt, out=out, dtype=dtype) + + +def sin(x, out=None, dtype=None): + return _build_unary_hlo(x, np.sin, out=out, dtype=dtype) + + +def cos(x, out=None, dtype=None): + return _build_unary_hlo(x, np.cos, out=out, dtype=dtype) + + +def tanh(x, out=None, dtype=None): + return _build_unary_hlo(x, np.tanh, out=out, dtype=dtype) + + +def ceil(x, out=None, dtype=None): + return _build_unary_hlo(x, np.ceil, out=out, dtype=dtype) + + +def floor(x, out=None, dtype=None): + return _build_unary_hlo(x, np.floor, out=out, dtype=dtype) + + +def sign(x, out=None, dtype=None): + return _build_unary_hlo(x, np.sign, out=out, dtype=dtype) + + +def negative(x, out=None, dtype=None): + return _build_unary_hlo(x, np.negative, out=out, dtype=dtype) + + +def arctan(x, out=None, dtype=None): + return _build_unary_hlo(x, np.arctan, out=out, dtype=dtype) + + +def invert(x, out=None, dtype=None): + return _build_unary_hlo(x, np.invert, out=out, dtype=dtype) + + +def bitwise_not(x, out=None, dtype=None): + return _build_unary_hlo(x, np.bitwise_not, out=out, dtype=dtype) + + + +# ============================================================================= +# Reduction ops +# ============================================================================= + + +def _build_reduction_hlo( + x, np_op, axis=None, out=None, dtype=None, keepdims=False, initial=None +): + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x = x.backend_tensor + + reduce_op_map = { + np.sum: "add", + np.max: "maximum", + np.min: "minimum", + np.prod: "multiply", + } + + if np_op not in reduce_op_map: + raise NotImplementedError( + f"Reduction operation {np_op} not yet supported in HLO tracing" + ) + + hlo_op = reduce_op_map[np_op] + + if axis is None: + dimensions_to_reduce = tuple(range(len(x.shape))) + elif isinstance(axis, int): + dim = axis if axis >= 0 else len(x.shape) + axis + dimensions_to_reduce = (dim,) + elif isinstance(axis, (list, tuple)): + dimensions_to_reduce = tuple( + ax if ax >= 0 else len(x.shape) + ax for ax in axis + ) + else: + dimensions_to_reduce = (axis,) + + reduced_shape = tuple( + s for i, s in enumerate(x.shape) if i not in dimensions_to_reduce + ) + + init_values = { + "add": 0.0, + "maximum": float("-inf"), + "minimum": float("inf"), + "multiply": 1.0, + } + init_value = init_values[hlo_op] + + init_tensor = as_hlo_tensor(ctx, init_value, x.dtype) + + result_tensor = ctx.build_op( + "reduce", + [x, init_tensor], + reduced_shape, + x.dtype, + { + "dimensions": list(dimensions_to_reduce), + "computation": hlo_op, + }, + ) + + if keepdims: + keepdims_shape = tuple( + 1 if i in dimensions_to_reduce else s for i, s in enumerate(x.shape) + ) + result_tensor = ctx.build_op( + "reshape", [result_tensor], keepdims_shape, x.dtype + ) + + return NKIPyTensorRef(result_tensor) + + +def _calculate_reduction_count(x_shape, axis): + if axis is None: + return int(np.prod(x_shape)) + elif isinstance(axis, int): + dim = axis if axis >= 0 else len(x_shape) + axis + return x_shape[dim] + elif isinstance(axis, (list, tuple)): + return int( + np.prod([x_shape[ax if ax >= 0 else len(x_shape) + ax] for ax in axis]) + ) + else: + return x_shape[axis] + + +def reduce_sum(x, axis=None, out=None, dtype=None, keepdims=False, initial=None): + return _build_reduction_hlo( + x, np.sum, axis=axis, out=out, dtype=dtype, keepdims=keepdims, initial=initial, + ) + + +def reduce_prod(x, axis=None, out=None, dtype=None, keepdims=False, initial=None): + return _build_reduction_hlo( + x, np.prod, axis=axis, out=out, dtype=dtype, keepdims=keepdims, initial=initial, + ) + + +def reduce_max(x, axis=None, out=None, dtype=None, keepdims=False, initial=None): + return _build_reduction_hlo( + x, np.max, axis=axis, out=out, dtype=dtype, keepdims=keepdims, initial=initial, + ) + + +def reduce_min(x, axis=None, out=None, dtype=None, keepdims=False, initial=None): + return _build_reduction_hlo( + x, np.min, axis=axis, out=out, dtype=dtype, keepdims=keepdims, initial=initial, + ) + + +def argmax(x, axis=None, out=None, keepdims=False): + from nkipy.core.ops.reduce import max as max_op, min as min_op + + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x_ref = x + x_bt = x.backend_tensor + else: + x_ref = NKIPyTensorRef(x) + x_bt = x + + original_axis = axis + x_ref_original_shape = x_ref.shape + + if axis is None: + from nkipy.core.ops.transform import reshape as reshape_op + + total = int(np.prod(x_ref.shape)) + x_ref = reshape_op(x_ref, (total,)) + x_bt = x_ref.backend_tensor + axis = 0 + + ndim = len(x_bt.shape) + if axis < 0: + axis = ndim + axis + + max_val = max_op(x_ref, axis=axis, keepdims=True) + + from nkipy.core.ops.binary import equal as equal_op + + mask = equal_op(x_ref, max_val) + + iota_tensor = ctx.build_op( + "iota", + [], + x_bt.shape, + np.dtype(np.float32), + {"iota_dimension": axis}, + ) + iota_ref = NKIPyTensorRef(iota_tensor) + + large_val = float(x_bt.shape[axis] + 1) + from nkipy.core.ops.indexing import where as where_op + + masked_indices = where_op(mask, iota_ref, large_val) + + result_float = min_op(masked_indices, axis=axis) + + from nkipy.core.ops.transform import astype as astype_op + + result = astype_op(result_float, np.dtype(np.int32)) + + if keepdims: + from nkipy.core.ops.transform import reshape as reshape_op + + if original_axis is not None: + keepdims_shape = list(x_ref_original_shape) + keepdims_shape[original_axis] = 1 + result = reshape_op(result, tuple(keepdims_shape)) + else: + keepdims_shape = tuple(1 for _ in x_ref_original_shape) + result = reshape_op(result, keepdims_shape) + + return result + + +def argmin(x, axis=None, out=None, keepdims=False): + from nkipy.core.ops.reduce import min as min_op + + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x_ref = x + x_bt = x.backend_tensor + else: + x_ref = NKIPyTensorRef(x) + x_bt = x + + original_axis = axis + x_ref_original_shape = x_ref.shape + + if axis is None: + from nkipy.core.ops.transform import reshape as reshape_op + + total = int(np.prod(x_ref.shape)) + x_ref = reshape_op(x_ref, (total,)) + x_bt = x_ref.backend_tensor + axis = 0 + + ndim = len(x_bt.shape) + if axis < 0: + axis = ndim + axis + + min_val = min_op(x_ref, axis=axis, keepdims=True) + + from nkipy.core.ops.binary import equal as equal_op + + mask = equal_op(x_ref, min_val) + + iota_tensor = ctx.build_op( + "iota", + [], + x_bt.shape, + np.dtype(np.float32), + {"iota_dimension": axis}, + ) + iota_ref = NKIPyTensorRef(iota_tensor) + + large_val = float(x_bt.shape[axis] + 1) + from nkipy.core.ops.indexing import where as where_op + + masked_indices = where_op(mask, iota_ref, large_val) + + result_float = min_op(masked_indices, axis=axis) + + from nkipy.core.ops.transform import astype as astype_op + + result = astype_op(result_float, np.dtype(np.int32)) + + if keepdims: + from nkipy.core.ops.transform import reshape as reshape_op + + if original_axis is not None: + keepdims_shape = list(x_ref_original_shape) + keepdims_shape[original_axis] = 1 + result = reshape_op(result, tuple(keepdims_shape)) + else: + keepdims_shape = tuple(1 for _ in x_ref_original_shape) + result = reshape_op(result, keepdims_shape) + + return result + + + +# ============================================================================= +# Linalg ops +# ============================================================================= + + +def matmul(x, y, out=None, dtype=None): + ctx = get_hlo_context() + + result_dtype = find_common_type_hlo(x, y) + + if isinstance(x, NKIPyTensorRef): + x = x.backend_tensor + if isinstance(y, NKIPyTensorRef): + y = y.backend_tensor + + assert len(x.shape) >= 1 and len(y.shape) >= 1, "matmul requires at least 1D arrays" + + squeeze_lhs = False + squeeze_rhs = False + + if len(x.shape) == 1 and len(y.shape) == 1: + assert x.shape[0] == y.shape[0], "Incompatible shapes for dot product" + result_shape = () + lhs_contracting_dims = [0] + rhs_contracting_dims = [0] + lhs_batch_dims = [] + rhs_batch_dims = [] + else: + if len(x.shape) == 1: + x = ctx.build_op("reshape", [x], (1, x.shape[0]), x.dtype) + squeeze_lhs = True + + if len(y.shape) == 1: + y = ctx.build_op("reshape", [y], (y.shape[0], 1), y.dtype) + squeeze_rhs = True + + assert x.shape[-1] == y.shape[-2], "Incompatible shapes for matmul" + + x_batch_shape = x.shape[:-2] + y_batch_shape = y.shape[:-2] + batch_shape = tuple(np.broadcast_shapes(x_batch_shape, y_batch_shape)) + result_shape = batch_shape + (x.shape[-2], y.shape[-1]) + + target_x_shape = batch_shape + tuple(x.shape[-2:]) + target_y_shape = batch_shape + tuple(y.shape[-2:]) + + if x.shape != target_x_shape: + x = broadcast_to_shape_hlo(ctx, x, target_x_shape) + + if y.shape != target_y_shape: + y = broadcast_to_shape_hlo(ctx, y, target_y_shape) + + lhs_contracting_dims = [len(target_x_shape) - 1] + rhs_contracting_dims = [len(target_y_shape) - 2] + + lhs_batch_dims = list(range(len(batch_shape))) + rhs_batch_dims = list(range(len(batch_shape))) + + result_tensor = ctx.build_op( + "dot", + [x, y], + result_shape, + result_dtype, + { + "lhs_contracting_dimensions": lhs_contracting_dims, + "rhs_contracting_dimensions": rhs_contracting_dims, + "lhs_batch_dimensions": lhs_batch_dims, + "rhs_batch_dimensions": rhs_batch_dims, + }, + ) + + if squeeze_lhs or squeeze_rhs: + final_shape = list(result_shape) + if squeeze_lhs: + final_shape.pop(-2) + if squeeze_rhs: + final_shape.pop(-1) + final_shape = tuple(final_shape) + result_tensor = ctx.build_op( + "reshape", [result_tensor], final_shape, result_dtype + ) + + return NKIPyTensorRef(result_tensor) + + +def trace(a, offset=0, axis1=0, axis2=1, dtype=None): + from nkipy.core.ops.binary import equal as equal_op, subtract as sub_op + from nkipy.core.ops.indexing import where as where_op + from nkipy.core.ops.reduce import sum as sum_op + + ctx = get_hlo_context() + + if isinstance(a, NKIPyTensorRef): + a_bt = a.backend_tensor + else: + a_bt = a + + shape = a_bt.shape + ndim = len(shape) + + if axis1 < 0: + axis1 += ndim + if axis2 < 0: + axis2 += ndim + + row_iota = ctx.build_op( + "iota", [], shape, np.dtype(np.int32), {"iota_dimension": axis1} + ) + row_ref = NKIPyTensorRef(row_iota) + + col_iota = ctx.build_op( + "iota", [], shape, np.dtype(np.int32), {"iota_dimension": axis2} + ) + col_ref = NKIPyTensorRef(col_iota) + + if offset != 0: + col_ref = sub_op(col_ref, offset) + + diag_mask = equal_op(row_ref, col_ref) + + masked = where_op(diag_mask, a, 0.0) + + axes_to_reduce = sorted([axis1, axis2], reverse=True) + result = masked + for ax in axes_to_reduce: + result = sum_op(result, axis=ax) + + if dtype is not None: + from nkipy.core.ops.transform import astype as astype_op + + result = astype_op(result, np.dtype(dtype)) + + return result + + +def dot(x, y, out=None): + ctx = get_hlo_context() + + result_dtype = find_common_type_hlo(x, y) + + if isinstance(x, NKIPyTensorRef): + x = x.backend_tensor + if isinstance(y, NKIPyTensorRef): + y = y.backend_tensor + + assert len(x.shape) >= 1 and len(y.shape) >= 1, "dot requires at least 1D arrays" + + lhs_contracting_dims = [len(x.shape) - 1] + rhs_contracting_dims = [max(0, len(y.shape) - 2)] + + assert x.shape[lhs_contracting_dims[0]] == y.shape[rhs_contracting_dims[0]], ( + f"shapes {x.shape} and {y.shape} not aligned" + ) + + lhs_batch_dims = [] + rhs_batch_dims = [] + + result_shape = tuple( + s for i, s in enumerate(x.shape) if i not in lhs_contracting_dims + ) + tuple(s for i, s in enumerate(y.shape) if i not in rhs_contracting_dims) + + result_tensor = ctx.build_op( + "dot", + [x, y], + result_shape, + result_dtype, + { + "lhs_contracting_dimensions": lhs_contracting_dims, + "rhs_contracting_dimensions": rhs_contracting_dims, + "lhs_batch_dimensions": lhs_batch_dims, + "rhs_batch_dimensions": rhs_batch_dims, + }, + ) + + return NKIPyTensorRef(result_tensor) + + +# ============================================================================= +# Creation ops +# ============================================================================= + + +def zeros(shape, dtype): + ctx = get_hlo_context() + + if isinstance(shape, int): + shape = (shape,) + + zero_tensor = as_hlo_tensor(ctx, 0.0, dtype) + + if shape: + result_tensor = ctx.build_op( + "broadcast", [zero_tensor], shape, dtype, {"broadcast_dimensions": []} + ) + else: + result_tensor = zero_tensor + + return NKIPyTensorRef(result_tensor) + + +def full(shape, fill_value, dtype): + ctx = get_hlo_context() + + if isinstance(shape, int): + shape = (shape,) + + fill_tensor = as_hlo_tensor(ctx, fill_value, dtype) + + if shape: + result_tensor = ctx.build_op( + "broadcast", [fill_tensor], shape, dtype, {"broadcast_dimensions": []} + ) + else: + result_tensor = fill_tensor + + return NKIPyTensorRef(result_tensor) + + +def constant(value, dtype=None): + if isinstance(value, NKIPyTensorRef): + if dtype is not None and value.dtype != np.dtype(dtype): + from nkipy.core.ops.transform import astype as astype_op + + return astype_op(value, dtype) + return value + + ctx = get_hlo_context() + + if dtype is not None: + target_dtype = np.dtype(dtype) + elif hasattr(value, "dtype"): + target_dtype = np.dtype(value.dtype) + elif isinstance(value, float): + target_dtype = np.dtype(np.float32) + elif isinstance(value, int): + target_dtype = np.dtype(np.int32) + elif isinstance(value, bool): + target_dtype = np.dtype(np.bool_) + else: + target_dtype = np.dtype(np.asarray(value).dtype) + + if isinstance(value, (list, tuple)): + value = np.asarray(value, dtype=target_dtype) + + hlo_tensor = as_hlo_tensor(ctx, value, target_dtype) + return NKIPyTensorRef(hlo_tensor) + + +def zeros_like(x, dtype=None): + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x_hlo = x.backend_tensor + else: + x_hlo = x + + result_dtype = dtype if dtype is not None else x_hlo.dtype + + zero_tensor = as_hlo_tensor(ctx, 0.0, result_dtype) + + if x_hlo.shape: + result_tensor = ctx.build_op( + "broadcast", + [zero_tensor], + x_hlo.shape, + result_dtype, + {"broadcast_dimensions": []}, + ) + else: + result_tensor = zero_tensor + + # FIXME: Workaround to ensure x is referenced in the computation graph. + zero_multiplier = as_hlo_tensor(ctx, 0.0, x_hlo.dtype) + if x_hlo.shape: + zero_multiplier = ctx.build_op( + "broadcast", + [zero_multiplier], + x_hlo.shape, + x_hlo.dtype, + {"broadcast_dimensions": []}, + ) + + x_times_zero = ctx.build_op( + "multiply", [x_hlo, zero_multiplier], x_hlo.shape, x_hlo.dtype + ) + + if x_hlo.dtype != result_dtype: + x_times_zero = ctx.build_op( + "convert", [x_times_zero], x_hlo.shape, result_dtype + ) + + result_tensor = ctx.build_op( + "add", [result_tensor, x_times_zero], x_hlo.shape, result_dtype + ) + + return NKIPyTensorRef(result_tensor) + + +def ones_like(x, dtype=None): + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x_hlo = x.backend_tensor + else: + x_hlo = x + + result_dtype = dtype if dtype is not None else x_hlo.dtype + + one_tensor = as_hlo_tensor(ctx, 1.0, result_dtype) + + if x_hlo.shape: + result_tensor = ctx.build_op( + "broadcast", + [one_tensor], + x_hlo.shape, + result_dtype, + {"broadcast_dimensions": []}, + ) + else: + result_tensor = one_tensor + + # FIXME: Workaround to ensure x is referenced in the computation graph + zero_multiplier = as_hlo_tensor(ctx, 0.0, x_hlo.dtype) + if x_hlo.shape: + zero_multiplier = ctx.build_op( + "broadcast", + [zero_multiplier], + x_hlo.shape, + x_hlo.dtype, + {"broadcast_dimensions": []}, + ) + + x_times_zero = ctx.build_op( + "multiply", [x_hlo, zero_multiplier], x_hlo.shape, x_hlo.dtype + ) + + if x_hlo.dtype != result_dtype: + x_times_zero = ctx.build_op( + "convert", [x_times_zero], x_hlo.shape, result_dtype + ) + + result_tensor = ctx.build_op( + "add", [result_tensor, x_times_zero], x_hlo.shape, result_dtype + ) + + return NKIPyTensorRef(result_tensor) + + +def empty_like(x, dtype=None): + return zeros_like(x, dtype=dtype) + + +def full_like(x, fill_value, dtype=None): + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x_hlo = x.backend_tensor + else: + x_hlo = x + + result_dtype = dtype if dtype is not None else x_hlo.dtype + + fill_tensor = as_hlo_tensor(ctx, fill_value, result_dtype) + + if x_hlo.shape: + result_tensor = ctx.build_op( + "broadcast", + [fill_tensor], + x_hlo.shape, + result_dtype, + {"broadcast_dimensions": []}, + ) + else: + result_tensor = fill_tensor + + # FIXME: Workaround to ensure x is referenced in the computation graph + zero_multiplier = as_hlo_tensor(ctx, 0.0, x_hlo.dtype) + if x_hlo.shape: + zero_multiplier = ctx.build_op( + "broadcast", + [zero_multiplier], + x_hlo.shape, + x_hlo.dtype, + {"broadcast_dimensions": []}, + ) + + x_times_zero = ctx.build_op( + "multiply", [x_hlo, zero_multiplier], x_hlo.shape, x_hlo.dtype + ) + + if x_hlo.dtype != result_dtype: + x_times_zero = ctx.build_op( + "convert", [x_times_zero], x_hlo.shape, result_dtype + ) + + result_tensor = ctx.build_op( + "add", [result_tensor, x_times_zero], x_hlo.shape, result_dtype + ) + + return NKIPyTensorRef(result_tensor) + + +def tril(x, k=0): + from nkipy.core.ops.binary import greater_equal as ge_op, subtract as sub_op + from nkipy.core.ops.indexing import where as where_op + + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x_bt = x.backend_tensor + else: + x_bt = x + + shape = x_bt.shape + ndim = len(shape) + + row_iota = ctx.build_op( + "iota", [], shape, np.dtype(np.int32), {"iota_dimension": ndim - 2} + ) + row_ref = NKIPyTensorRef(row_iota) + + col_iota = ctx.build_op( + "iota", [], shape, np.dtype(np.int32), {"iota_dimension": ndim - 1} + ) + col_ref = NKIPyTensorRef(col_iota) + + if k != 0: + col_ref = sub_op(col_ref, k) + mask = ge_op(row_ref, col_ref) + + return where_op(mask, x, 0.0) + + +def triu(x, k=0): + from nkipy.core.ops.binary import less_equal as le_op, subtract as sub_op + from nkipy.core.ops.indexing import where as where_op + + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x_bt = x.backend_tensor + else: + x_bt = x + + shape = x_bt.shape + ndim = len(shape) + + row_iota = ctx.build_op( + "iota", [], shape, np.dtype(np.int32), {"iota_dimension": ndim - 2} + ) + row_ref = NKIPyTensorRef(row_iota) + + col_iota = ctx.build_op( + "iota", [], shape, np.dtype(np.int32), {"iota_dimension": ndim - 1} + ) + col_ref = NKIPyTensorRef(col_iota) + + if k != 0: + col_ref = sub_op(col_ref, k) + mask = le_op(row_ref, col_ref) + + return where_op(mask, x, 0.0) + + +def diag(v, k=0): + from nkipy.core.ops.binary import equal as equal_op, subtract as sub_op + from nkipy.core.ops.indexing import take as take_op, where as where_op + from nkipy.core.ops.reduce import sum as sum_op + from nkipy.core.ops.transform import broadcast_to as bcast_op, reshape as reshape_op + from nkipy.core.ops.unary import clip as clip_op + + ctx = get_hlo_context() + + if isinstance(v, NKIPyTensorRef): + v_bt = v.backend_tensor + else: + v_bt = v + + ndim = len(v_bt.shape) + + if ndim == 1: + n = v_bt.shape[0] + builtins.abs(k) + shape_2d = (n, n) + + row_iota = ctx.build_op( + "iota", [], shape_2d, np.dtype(np.int32), {"iota_dimension": 0} + ) + row_ref = NKIPyTensorRef(row_iota) + col_iota = ctx.build_op( + "iota", [], shape_2d, np.dtype(np.int32), {"iota_dimension": 1} + ) + col_ref = NKIPyTensorRef(col_iota) + + if k != 0: + col_ref = sub_op(col_ref, k) + + diag_mask = equal_op(row_ref, col_ref) + + if k >= 0: + idx_ref = row_ref + else: + idx_ref = NKIPyTensorRef(col_iota) + + idx_ref = clip_op(idx_ref, 0, v_bt.shape[0] - 1) + + v_gathered = take_op(v, idx_ref, axis=0) + return where_op(diag_mask, v_gathered, 0.0) + + elif ndim == 2: + rows, cols = v_bt.shape + if k >= 0: + diag_len = builtins_min(rows, cols - k) + else: + diag_len = builtins_min(rows + k, cols) + + if diag_len <= 0: + return zeros((0,), v_bt.dtype) + + shape_2d = v_bt.shape + row_iota = ctx.build_op( + "iota", [], shape_2d, np.dtype(np.int32), {"iota_dimension": 0} + ) + row_ref = NKIPyTensorRef(row_iota) + col_iota = ctx.build_op( + "iota", [], shape_2d, np.dtype(np.int32), {"iota_dimension": 1} + ) + col_ref = NKIPyTensorRef(col_iota) + + if k != 0: + col_ref = sub_op(col_ref, k) + + diag_mask = equal_op(row_ref, col_ref) + + masked = where_op(diag_mask, v, 0.0) + + if k >= 0: + result = sum_op(masked, axis=1) + else: + result = sum_op(masked, axis=0) + + result_shape = result.shape + if result_shape[0] != diag_len: + from nkipy.core.ops.indexing import static_slice as static_slice_op + + result = static_slice_op(result, [0], [diag_len], [1], []) + + return result + + else: + raise ValueError(f"Input must be 1-D or 2-D, got {ndim}-D") + + +# ============================================================================= +# Transform ops +# ============================================================================= + + +def reshape(x, newshape): + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x = x.backend_tensor + + if isinstance(newshape, int): + newshape = (newshape,) + + if -1 in newshape: + total_size = int(np.prod(x.shape)) + known_size = int(np.prod([d for d in newshape if d != -1])) + assert known_size > 0, "Cannot reshape to a size of 0" + assert total_size % known_size == 0, ( + f"Cannot reshape array of size {total_size} into shape {newshape}" + ) + newshape = tuple(total_size // known_size if d == -1 else d for d in newshape) + + if np.prod(x.shape) != np.prod(newshape): + raise ValueError( + f"Cannot reshape array of size {np.prod(x.shape)} into shape {newshape}" + ) + + result_tensor = ctx.build_op("reshape", [x], newshape, x.dtype) + return NKIPyTensorRef(result_tensor) + + +def transpose(x, axes=None, out=None, dtype=None): + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x = x.backend_tensor + + if axes is None: + axes = list(range(len(x.shape)))[::-1] + + result_shape = tuple(x.shape[i] for i in axes) + + result_tensor = ctx.build_op( + "transpose", [x], result_shape, x.dtype, {"permutation": axes} + ) + return NKIPyTensorRef(result_tensor) + + +def swapaxes(x, axis1, axis2): + ndim = len(x.shape) + if axis1 < 0: + axis1 += ndim + if axis2 < 0: + axis2 += ndim + if not (0 <= axis1 < ndim) or not (0 <= axis2 < ndim): + raise np.AxisError(f"axis is out of bounds for array of dimension {ndim}") + axes = list(range(ndim)) + axes[axis1], axes[axis2] = axes[axis2], axes[axis1] + from nkipy.core.ops.transform import transpose as transpose_op + + return transpose_op(x, axes=axes) + + +def stack(arrays, axis=0, out=None, dtype=None): + from nkipy.core.ops.transform import concatenate as concat_op, expand_dims as expand_op + + expanded = [expand_op(a, axis=axis) for a in arrays] + return concat_op(expanded, axis=axis) + + +def expand_dims(x, axis): + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x = x.backend_tensor + + rank = len(x.shape) + + if isinstance(axis, (list, tuple)): + final_rank = rank + len(axis) + + axes = [] + for ax in axis: + if ax < 0: + ax = final_rank + ax + if ax < 0 or ax > final_rank - 1: + raise ValueError( + f"axis {ax} is out of bounds for array of dimension {final_rank}" + ) + axes.append(ax) + + if len(axes) != len(set(axes)): + raise ValueError("repeated axis in expand_dims") + + axes = sorted(axes) + + new_shape = list(x.shape) + for ax in axes: + new_shape.insert(ax, 1) + new_shape = tuple(new_shape) + else: + if axis < 0: + axis = rank + axis + 1 + + if axis < 0 or axis > rank: + raise ValueError( + f"axis {axis} is out of bounds for array of dimension {rank}" + ) + + new_shape = list(x.shape) + new_shape.insert(axis, 1) + new_shape = tuple(new_shape) + + result_tensor = ctx.build_op("reshape", [x], new_shape, x.dtype) + return NKIPyTensorRef(result_tensor) + + +def concatenate(tensors, axis=0): + ctx = get_hlo_context() + + hlo_tensors = [] + for t in tensors: + if isinstance(t, NKIPyTensorRef): + hlo_tensors.append(t.backend_tensor) + elif isinstance(t, np.ndarray): + from nkipy.core.ops.creation import constant as constant_op + + const_ref = constant_op(t) + hlo_tensors.append(const_ref.backend_tensor) + else: + hlo_tensors.append(t) + + if not hlo_tensors: + raise ValueError("Need at least one tensor to concatenate") + + if len(hlo_tensors) == 1: + result_tensor = ctx.build_op( + "copy", [hlo_tensors[0]], hlo_tensors[0].shape, hlo_tensors[0].dtype + ) + return NKIPyTensorRef(result_tensor) + + ndim = len(hlo_tensors[0].shape) + if axis < 0: + axis = ndim + axis + + if axis < 0 or axis >= ndim: + raise ValueError(f"axis {axis} is out of bounds for array of dimension {ndim}") + + output_shape = list(hlo_tensors[0].shape) + output_shape[axis] = builtins.sum(t.shape[axis] for t in hlo_tensors) + output_shape = tuple(output_shape) + + dtype = hlo_tensors[0].dtype + for t in hlo_tensors[1:]: + dtype = np.result_type(dtype, t.dtype) + + result_tensor = ctx.build_op( + "concatenate", hlo_tensors, output_shape, dtype, {"dimension": axis} + ) + return NKIPyTensorRef(result_tensor) + + +def split(x, indices_or_sections, axis=0): + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x = x.backend_tensor + + if axis < 0: + axis = len(x.shape) + axis + + if axis < 0 or axis >= len(x.shape): + raise ValueError( + f"axis {axis} is out of bounds for array of dimension {len(x.shape)}" + ) + + axis_size = x.shape[axis] + + if isinstance(indices_or_sections, int): + n_sections = indices_or_sections + if n_sections <= 0: + raise ValueError("Number of sections must be larger than 0") + if axis_size % n_sections != 0: + raise ValueError("Array split does not result in an equal division") + + section_size = axis_size // n_sections + split_indices = [i * section_size for i in range(1, n_sections)] + else: + split_indices = list(indices_or_sections) + + split_points = [0] + split_indices + [axis_size] + + result_tensors = [] + for i in range(len(split_points) - 1): + start_idx = split_points[i] + end_idx = split_points[i + 1] + + start_indices = [0] * len(x.shape) + limit_indices = list(x.shape) + strides = [1] * len(x.shape) + + start_indices[axis] = start_idx + limit_indices[axis] = end_idx + + slice_shape = list(x.shape) + slice_shape[axis] = end_idx - start_idx + slice_shape = tuple(slice_shape) + + slice_tensor = ctx.build_op( + "slice", + [x], + slice_shape, + x.dtype, + { + "start_indices": start_indices, + "limit_indices": limit_indices, + "strides": strides, + }, + ) + + result_tensors.append(NKIPyTensorRef(slice_tensor)) + + return result_tensors + + +def copy(x, out=None, dtype=None): + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x = x.backend_tensor + + result_tensor = ctx.build_op("copy", [x], x.shape, x.dtype) + return NKIPyTensorRef(result_tensor) + + +def repeat(x, repeats, axis=None): + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x = x.backend_tensor + + if axis is None: + flattened_shape = (int(np.prod(x.shape)),) + x = ctx.build_op("reshape", [x], flattened_shape, x.dtype) + axis = 0 + + if axis < 0: + axis = len(x.shape) + axis + + if not isinstance(repeats, (int, np.integer)): + raise TypeError( + f"Only compile-time-known integer repeats are supported, got {type(repeats).__name__}. " + "Dynamic tensor repeats are not supported in tracing." + ) + repeats = int(repeats) + + new_shape = list(x.shape) + new_shape[axis] *= repeats + new_shape = tuple(new_shape) + + broadcast_shape = list(x.shape) + broadcast_shape.insert(axis + 1, repeats) + broadcast_shape = tuple(broadcast_shape) + + broadcast_dims = [i if i <= axis else i + 1 for i in range(len(x.shape))] + x_broadcast = ctx.build_op( + "broadcast", + [x], + broadcast_shape, + x.dtype, + {"broadcast_dimensions": broadcast_dims}, + ) + + result_tensor = ctx.build_op("reshape", [x_broadcast], new_shape, x.dtype) + + return NKIPyTensorRef(result_tensor) + + +def broadcast_to(x, shape, out=None, dtype=None): + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x = x.backend_tensor + + if isinstance(shape, int): + shape = (shape,) + target_shape = tuple(shape) + + if x.shape == target_shape: + result_tensor = ctx.build_op("copy", [x], x.shape, x.dtype) + return NKIPyTensorRef(result_tensor) + + result_tensor = broadcast_to_shape_hlo(ctx, x, target_shape) + return NKIPyTensorRef(result_tensor) + + +def astype(x, dtype): + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x = x.backend_tensor + + if x.dtype == dtype: + result_tensor = ctx.build_op("copy", [x], x.shape, x.dtype) + else: + result_tensor = ctx.build_op("convert", [x], x.shape, dtype) + + return NKIPyTensorRef(result_tensor) + + +def squeeze(x, axis=None): + x_shape = x.shape + ndim = len(x_shape) + + if axis is None: + new_shape = tuple(s for s in x_shape if s != 1) + if not new_shape: + new_shape = () + else: + if isinstance(axis, int): + axis = (axis,) + axes = tuple(a if a >= 0 else ndim + a for a in axis) + for a in axes: + if x_shape[a] != 1: + raise ValueError( + f"cannot select an axis to squeeze out which has size " + f"not equal to one, got shape[{a}] = {x_shape[a]}" + ) + new_shape = tuple(s for i, s in enumerate(x_shape) if i not in axes) + + if new_shape == x_shape: + return x + from nkipy.core.ops.transform import reshape as reshape_op + + return reshape_op(x, new_shape) + + +def pad(x, pad_width, mode="constant", constant_values=0, **kwargs): + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x_shape = x.shape + x_dtype = x.dtype + x_bt = x.backend_tensor + else: + x_shape = x.shape + x_dtype = x.dtype + x_bt = x + + ndim = len(x_shape) + + pad_width_arr = np.asarray(pad_width) + if pad_width_arr.ndim == 0: + pad_width_arr = np.broadcast_to(pad_width_arr, (ndim, 2)) + elif pad_width_arr.ndim == 1: + if len(pad_width_arr) == 2: + pad_width_arr = np.broadcast_to(pad_width_arr, (ndim, 2)) + else: + pad_width_arr = np.array([[p, p] for p in pad_width_arr]) + if len(pad_width_arr) != ndim: + raise ValueError( + f"pad_width must have length {ndim} to match array dimensions, " + f"got {len(pad_width_arr)}" + ) + if pad_width_arr.ndim == 2 and len(pad_width_arr) == 1: + pad_width_arr = np.broadcast_to(pad_width_arr, (ndim, 2)) + pad_width_list = [(int(pad_width_arr[i, 0]), int(pad_width_arr[i, 1])) for i in range(ndim)] + + if mode == "constant": + padding_config = [(low, high, 0) for low, high in pad_width_list] + + result_shape = tuple( + s + low + high for s, (low, high) in zip(x_shape, pad_width_list) + ) + + pad_value_tensor = as_hlo_tensor(ctx, constant_values, x_dtype) + + result_tensor = ctx.build_op( + "pad", + [x_bt, pad_value_tensor], + result_shape, + x_dtype, + {"padding_config": padding_config}, + ) + return NKIPyTensorRef(result_tensor) + + elif mode == "edge": + from nkipy.core.ops.transform import ( + concatenate as concat_op, + expand_dims as expand_op, + repeat as repeat_op, + ) + + result = NKIPyTensorRef(x_bt) if not isinstance(x, NKIPyTensorRef) else x + for dim in range(ndim): + before, after = pad_width_list[dim] + if before == 0 and after == 0: + continue + + parts = [] + if before > 0: + edge_slice = _slice_single(result, dim, 0) + edge_expanded = expand_op(edge_slice, axis=dim) + edge_repeated = repeat_op(edge_expanded, before, axis=dim) + parts.append(edge_repeated) + + parts.append(result) + + if after > 0: + last_idx = result.shape[dim] - 1 + edge_slice = _slice_single(result, dim, last_idx) + edge_expanded = expand_op(edge_slice, axis=dim) + edge_repeated = repeat_op(edge_expanded, after, axis=dim) + parts.append(edge_repeated) + + result = concat_op(parts, axis=dim) + + return result + + else: + raise NotImplementedError( + f"Pad mode '{mode}' is not supported. Only 'constant' and 'edge' modes are available." + ) + + +def diff(a, n=1, axis=-1, prepend=None, append=None): + from nkipy.core.ops.binary import subtract as sub_op + + ctx = get_hlo_context() + + ndim = len(a.shape) + if axis < 0: + axis += ndim + + result = a + for _ in range(n): + if isinstance(result, NKIPyTensorRef): + r_bt = result.backend_tensor + else: + r_bt = result + + axis_size = r_bt.shape[axis] + + start1 = [0] * ndim + limit1 = list(r_bt.shape) + start1[axis] = 1 + shape1 = list(r_bt.shape) + shape1[axis] = axis_size - 1 + + t1 = ctx.build_op( + "slice", + [r_bt], + tuple(shape1), + r_bt.dtype, + {"start_indices": start1, "limit_indices": limit1, "strides": [1] * ndim}, + ) + + start0 = [0] * ndim + limit0 = list(r_bt.shape) + limit0[axis] = axis_size - 1 + shape0 = list(r_bt.shape) + shape0[axis] = axis_size - 1 + + t0 = ctx.build_op( + "slice", + [r_bt], + tuple(shape0), + r_bt.dtype, + {"start_indices": start0, "limit_indices": limit0, "strides": [1] * ndim}, + ) + + result = sub_op(NKIPyTensorRef(t1), NKIPyTensorRef(t0)) + + return result + + +def flip(x, axis=None): + from nkipy.core.ops.indexing import take as take_op + + ndim = len(x.shape) + + if axis is None: + axes = list(range(ndim)) + elif isinstance(axis, int): + axes = [axis if axis >= 0 else axis + ndim] + else: + axes = [a if a >= 0 else a + ndim for a in axis] + + result = x + for ax in axes: + n = result.shape[ax] + reversed_indices = np.arange(n - 1, -1, -1, dtype=np.int32) + result = take_op(result, reversed_indices, axis=ax) + + return result + + +def tile(x, reps): + from nkipy.core.ops.transform import ( + broadcast_to as bcast_op, + copy as copy_op, + reshape as reshape_op, + ) + + if isinstance(reps, int): + reps = (reps,) + reps = tuple(reps) + + x_shape = x.shape + ndim = len(x_shape) + + if len(reps) < ndim: + reps = (1,) * (ndim - len(reps)) + reps + elif len(reps) > ndim: + x = reshape_op(x, (1,) * (len(reps) - ndim) + x_shape) + x_shape = x.shape + ndim = len(x_shape) + + if all(r == 1 for r in reps): + return copy_op(x) + + interleaved = [] + for r, s in zip(reps, x_shape): + interleaved.append(1) + interleaved.append(s) + result = reshape_op(x, tuple(interleaved)) + + bcast_shape = list(result.shape) + for i, r in enumerate(reps): + bcast_shape[i * 2] = r + result = bcast_op(result, tuple(bcast_shape)) + + final_shape = tuple(r * s for r, s in zip(reps, x_shape)) + return reshape_op(result, final_shape) + + +def roll(x, shift, axis=None): + from nkipy.core.ops.transform import reshape as reshape_op + + x_shape = x.shape + ndim = len(x_shape) + + if axis is None: + total = int(np.prod(x_shape)) + flat = reshape_op(x, (total,)) + rolled = _roll_single_axis(flat, shift, 0) + return reshape_op(rolled, x_shape) + + if isinstance(shift, (list, tuple)): + if not isinstance(axis, (list, tuple)): + raise ValueError("If shift is a tuple, axis must also be a tuple") + result = x + for s, a in zip(shift, axis): + result = _roll_single_axis(result, s, a if a >= 0 else a + ndim) + return result + + if axis < 0: + axis += ndim + return _roll_single_axis(x, shift, axis) + + +def _roll_single_axis(x, shift, axis): + from nkipy.core.ops.transform import concatenate as concat_op + + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x_bt = x.backend_tensor + else: + x_bt = x + + axis_size = x_bt.shape[axis] + ndim = len(x_bt.shape) + + shift = shift % axis_size + if shift == 0: + return NKIPyTensorRef(x_bt) if not isinstance(x, NKIPyTensorRef) else x + + split_point = axis_size - shift + + start1 = [0] * ndim + limit1 = list(x_bt.shape) + start1[axis] = split_point + shape1 = list(x_bt.shape) + shape1[axis] = shift + + t1 = ctx.build_op( + "slice", + [x_bt], + tuple(shape1), + x_bt.dtype, + {"start_indices": start1, "limit_indices": limit1, "strides": [1] * ndim}, + ) + + start0 = [0] * ndim + limit0 = list(x_bt.shape) + limit0[axis] = split_point + shape0 = list(x_bt.shape) + shape0[axis] = split_point + + t0 = ctx.build_op( + "slice", + [x_bt], + tuple(shape0), + x_bt.dtype, + {"start_indices": start0, "limit_indices": limit0, "strides": [1] * ndim}, + ) + + return concat_op([NKIPyTensorRef(t1), NKIPyTensorRef(t0)], axis=axis) + + +def _slice_single(x, dim, index): + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x_bt = x.backend_tensor + else: + x_bt = x + + ndim = len(x_bt.shape) + start_indices = [0] * ndim + limit_indices = list(x_bt.shape) + strides_list = [1] * ndim + + start_indices[dim] = index + limit_indices[dim] = index + 1 + + slice_shape = list(x_bt.shape) + slice_shape[dim] = 1 + + sliced = ctx.build_op( + "slice", + [x_bt], + tuple(slice_shape), + x_bt.dtype, + { + "start_indices": start_indices, + "limit_indices": limit_indices, + "strides": strides_list, + }, + ) + + result_shape = tuple(s for i, s in enumerate(x_bt.shape) if i != dim) + result = ctx.build_op("reshape", [sliced], result_shape, x_bt.dtype) + return NKIPyTensorRef(result) + + +# ============================================================================= +# Indexing ops +# ============================================================================= + + +def where(condition, x, y): + ctx = get_hlo_context() + + output_dtype = find_common_type_hlo(x, y) + + if isinstance(condition, NKIPyTensorRef): + condition = condition.backend_tensor + elif np.isscalar(condition): + condition = as_hlo_tensor(ctx, bool(condition), np.bool_) + elif isinstance(condition, np.ndarray): + condition = as_hlo_tensor(ctx, condition.astype(bool), np.bool_) + + if hasattr(condition, "dtype") and condition.dtype != np.bool_: + zero = as_hlo_tensor(ctx, 0, condition.dtype) + if condition.shape: + zero = ctx.build_op( + "broadcast", + [zero], + condition.shape, + condition.dtype, + {"broadcast_dimensions": []}, + ) + condition = ctx.build_op( + "compare", + [condition, zero], + condition.shape, + np.bool_, + {"comparison_direction": "NE"}, + ) + + if isinstance(x, NKIPyTensorRef): + x = x.backend_tensor + elif np.isscalar(x): + x = as_hlo_tensor(ctx, x, output_dtype) + elif isinstance(x, np.ndarray): + const_op = HLOOp( + "constant", + [], + result_shape=x.shape, + result_dtype=x.dtype, + attributes={"value": x}, + ) + x = ctx.module.add_operation(const_op) + + if isinstance(y, NKIPyTensorRef): + y = y.backend_tensor + elif np.isscalar(y): + y = as_hlo_tensor(ctx, y, output_dtype) + elif isinstance(y, np.ndarray): + const_op = HLOOp( + "constant", + [], + result_shape=y.shape, + result_dtype=y.dtype, + attributes={"value": y}, + ) + y = ctx.module.add_operation(const_op) + + broadcast_shape = tuple(np.broadcast_shapes(condition.shape, x.shape, y.shape)) + + if condition.shape != broadcast_shape: + condition = broadcast_to_shape_hlo(ctx, condition, broadcast_shape) + + if x.shape != broadcast_shape: + x = broadcast_to_shape_hlo(ctx, x, broadcast_shape) + + if y.shape != broadcast_shape: + y = broadcast_to_shape_hlo(ctx, y, broadcast_shape) + + result_tensor = ctx.build_op( + "select", [condition, x, y], broadcast_shape, output_dtype + ) + + return NKIPyTensorRef(result_tensor) + + +def take(x, indices, axis=None): + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x = x.backend_tensor + + if axis is None: + flattened_shape = (int(np.prod(x.shape)),) + x = ctx.build_op("reshape", [x], flattened_shape, x.dtype) + axis = 0 + + if axis < 0: + axis = len(x.shape) + axis + + dtype = x.dtype + + if isinstance(indices, NKIPyTensorRef): + indices_tensor = indices.backend_tensor + elif np.isscalar(indices): + if indices < 0: + indices = x.shape[axis] + indices + indices_tensor = as_hlo_tensor(ctx, int(indices), np.dtype(np.int32)) + elif isinstance(indices, (np.ndarray, list)): + if isinstance(indices, list): + indices_np = np.array(indices, dtype=np.int32) + else: + indices_np = indices.astype(np.int32) + const_op = HLOOp( + "constant", + [], + result_shape=indices_np.shape, + result_dtype=np.dtype(np.int32), + attributes={"value": indices_np}, + ) + indices_tensor = ctx.module.add_operation(const_op) + else: + raise ValueError( + "np.take only supports TensorRef, scalar, np.ndarray, or list as indices!" + ) + + indices_shape = indices_tensor.shape if hasattr(indices_tensor, "shape") else () + + output_shape = [] + for i in range(len(x.shape)): + if i == axis: + output_shape.extend(indices_shape) + else: + output_shape.append(x.shape[i]) + output_shape = tuple(output_shape) + + offset_dims = [] + for i in range(len(output_shape)): + if i < axis or i >= axis + len(indices_shape): + offset_dims.append(i) + + collapsed_slice_dims = [axis] + start_index_map = [axis] + index_vector_dim = len(indices_shape) + + slice_sizes = list(x.shape) + slice_sizes[axis] = 1 + + result_tensor = ctx.build_op( + "gather", + [x, indices_tensor], + output_shape, + dtype, + { + "offset_dims": offset_dims, + "collapsed_slice_dims": collapsed_slice_dims, + "start_index_map": start_index_map, + "index_vector_dim": index_vector_dim, + "slice_sizes": slice_sizes, + "indices_are_sorted": False, + }, + ) + + return NKIPyTensorRef(result_tensor) + + +def take_along_axis(x, indices, axis): + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x = x.backend_tensor + + if axis is None: + flattened_shape = (int(np.prod(x.shape)),) + x = ctx.build_op("reshape", [x], flattened_shape, x.dtype) + axis = 0 + + if axis < 0: + axis = len(x.shape) + axis + + if isinstance(indices, NKIPyTensorRef): + indices_tensor = indices.backend_tensor + elif isinstance(indices, np.ndarray): + indices_np = indices.astype(np.int32) + const_op = HLOOp( + "constant", + [], + result_shape=indices_np.shape, + result_dtype=np.dtype(np.int32), + attributes={"value": indices_np}, + ) + indices_tensor = ctx.module.add_operation(const_op) + else: + raise ValueError( + "take_along_axis only supports TensorRef or np.ndarray as indices!" + ) + + data_rank = len(x.shape) + + if indices_tensor.dtype != np.dtype(np.int32): + indices_tensor = ctx.build_op( + "convert", [indices_tensor], indices_tensor.shape, np.dtype(np.int32) + ) + + target_indices_shape = list(x.shape) + target_indices_shape[axis] = ( + indices_tensor.shape[axis] if axis < len(indices_tensor.shape) else 1 + ) + target_indices_shape = tuple(target_indices_shape) + + if indices_tensor.shape != target_indices_shape: + indices_tensor = broadcast_to_shape_hlo( + ctx, indices_tensor, target_indices_shape + ) + + index_arrays = [] + for i in range(data_rank): + if i == axis: + index_arrays.append(indices_tensor) + else: + arange_shape = [1] * data_rank + arange_shape[i] = x.shape[i] + arange_shape = tuple(arange_shape) + + arange_vals = np.arange(x.shape[i], dtype=np.int32) + const_op = HLOOp( + "constant", + [], + result_shape=(x.shape[i],), + result_dtype=np.dtype(np.int32), + attributes={"value": arange_vals}, + ) + arange_tensor = ctx.module.add_operation(const_op) + + arange_tensor = ctx.build_op( + "reshape", [arange_tensor], arange_shape, np.dtype(np.int32) + ) + index_arrays.append(arange_tensor) + + broadcast_shape = target_indices_shape + broadcasted_indices = [] + for idx_array in index_arrays: + if idx_array.shape != broadcast_shape: + broadcasted = broadcast_to_shape_hlo(ctx, idx_array, broadcast_shape) + broadcasted_indices.append(broadcasted) + else: + broadcasted_indices.append(idx_array) + + reshaped_indices = [] + for idx in broadcasted_indices: + new_shape = idx.shape + (1,) + reshaped = ctx.build_op("reshape", [idx], new_shape, np.dtype(np.int32)) + reshaped_indices.append(reshaped) + + stacked_shape = broadcast_shape + (data_rank,) + gather_indices = ctx.build_op( + "concatenate", + reshaped_indices, + stacked_shape, + np.dtype(np.int32), + {"dimension": data_rank}, + ) + + dnums = { + "offset_dims": [], + "collapsed_slice_dims": list(range(data_rank)), + "start_index_map": list(range(data_rank)), + "index_vector_dim": data_rank, + "slice_sizes": [1] * data_rank, + "indices_are_sorted": False, + } + + result_tensor = ctx.build_op( + "gather", + [x, gather_indices], + broadcast_shape, + x.dtype, + dnums, + ) + + return NKIPyTensorRef(result_tensor) + + +def scatter_along_axis(x, indices, values, axis): + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x = x.backend_tensor + + x_copy = ctx.build_op("copy", [x], x.shape, x.dtype) + + if axis < 0: + axis = len(x.shape) + axis + + if isinstance(indices, NKIPyTensorRef): + indices_tensor = indices.backend_tensor + elif isinstance(indices, np.ndarray): + indices_np = indices.astype(np.int32) + const_op = HLOOp( + "constant", + [], + result_shape=indices_np.shape, + result_dtype=np.dtype(np.int32), + attributes={"value": indices_np}, + ) + indices_tensor = ctx.module.add_operation(const_op) + else: + raise ValueError("scatter_along_axis requires TensorRef or np.ndarray indices") + + if isinstance(values, NKIPyTensorRef): + values_tensor = values.backend_tensor + elif isinstance(values, np.ndarray): + values_np = values.astype(x.dtype) + const_op = HLOOp( + "constant", + [], + result_shape=values_np.shape, + result_dtype=x.dtype, + attributes={"value": values_np}, + ) + values_tensor = ctx.module.add_operation(const_op) + else: + values_tensor = as_hlo_tensor(ctx, values, x.dtype) + + update_window_dims = [i for i in range(len(x_copy.shape)) if i != axis] + scattered_tensor = ctx.build_op( + "scatter", + [x_copy, indices_tensor, values_tensor], + x_copy.shape, + x.dtype, + { + "update_window_dims": update_window_dims, + "inserted_window_dims": [axis], + "scatter_dims_to_operand_dims": [axis], + "index_vector_dim": len(indices_tensor.shape), + "update_computation": "assign", + "indices_are_sorted": False, + "unique_indices": False, + }, + ) + + return NKIPyTensorRef(scattered_tensor) + + +def put_along_axis(x, indices, values, axis): + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x = x.backend_tensor + + x_shape = x.shape + x_dtype = x.dtype + x_copy = ctx.build_op("copy", [x], x_shape, x_dtype) + + if axis is None: + axis = 0 + effective_shape = (int(np.prod(x_shape)),) + else: + if axis < 0: + axis = len(x_shape) + axis + effective_shape = x_shape + + if isinstance(indices, NKIPyTensorRef): + indices_tensor = indices.backend_tensor + elif isinstance(indices, np.ndarray): + indices_np = indices.astype(np.int32) + const_op = HLOOp( + "constant", + [], + result_shape=indices_np.shape, + result_dtype=np.dtype(np.int32), + attributes={"value": indices_np}, + ) + indices_tensor = ctx.module.add_operation(const_op) + else: + raise ValueError( + "put_along_axis only supports TensorRef or np.ndarray as indices!" + ) + + if indices_tensor.dtype != np.dtype(np.int32): + indices_tensor = ctx.build_op( + "convert", [indices_tensor], indices_tensor.shape, np.dtype(np.int32) + ) + + idx_shape = indices_tensor.shape + + if np.isscalar(values): + scalar_tensor = as_hlo_tensor(ctx, values, x_dtype) + if idx_shape: + values_tensor = ctx.build_op( + "broadcast", + [scalar_tensor], + idx_shape, + x_dtype, + {"broadcast_dimensions": []}, + ) + else: + values_tensor = scalar_tensor + elif isinstance(values, NKIPyTensorRef): + values_tensor = values.backend_tensor + elif isinstance(values, np.ndarray): + values_np = values.astype(x_dtype) + const_op = HLOOp( + "constant", + [], + result_shape=values_np.shape, + result_dtype=x_dtype, + attributes={"value": values_np}, + ) + values_tensor = ctx.module.add_operation(const_op) + else: + raise ValueError( + "put_along_axis only supports scalar, TensorRef, or np.ndarray as values!" + ) + + if values_tensor.shape != idx_shape: + values_tensor = ctx.build_op( + "reshape", [values_tensor], idx_shape, values_tensor.dtype + ) + + ndim = len(effective_shape) + strides = [1] * ndim + for d in range(ndim - 2, -1, -1): + strides[d] = strides[d + 1] * effective_shape[d + 1] + + offset_np = np.zeros(idx_shape, dtype=np.int32) + for d in range(ndim): + if d == axis: + continue + coord = np.arange(idx_shape[d], dtype=np.int32) + bcast = [1] * len(idx_shape) + bcast[d] = idx_shape[d] + offset_np = offset_np + coord.reshape(bcast) * strides[d] + + offset_const = ctx.build_op( + "constant", [], idx_shape, np.dtype(np.int32), {"value": offset_np} + ) + + axis_stride_scalar = ctx.build_op( + "constant", + [], + (), + np.dtype(np.int32), + {"value": np.int32(strides[axis])}, + ) + axis_stride = ctx.build_op( + "broadcast", + [axis_stride_scalar], + idx_shape, + np.dtype(np.int32), + {"broadcast_dimensions": []}, + ) + scaled = ctx.build_op( + "multiply", + [indices_tensor, axis_stride], + idx_shape, + np.dtype(np.int32), + ) + flat_indices = ctx.build_op( + "add", + [scaled, offset_const], + idx_shape, + np.dtype(np.int32), + ) + + flat_size = int(np.prod(effective_shape)) + num_elements = int(np.prod(idx_shape)) + + x_flat = ctx.build_op("reshape", [x_copy], (flat_size,), x_dtype) + flat_indices_1d = ctx.build_op( + "reshape", + [flat_indices], + (num_elements,), + np.dtype(np.int32), + ) + flat_values_1d = ctx.build_op( + "reshape", + [values_tensor], + (num_elements,), + x_dtype, + ) + + scattered = ctx.build_op( + "scatter", + [x_flat, flat_indices_1d, flat_values_1d], + (flat_size,), + x_dtype, + { + "update_window_dims": [], + "inserted_window_dims": [0], + "scatter_dims_to_operand_dims": [0], + "index_vector_dim": 1, + "update_computation": "assign", + "indices_are_sorted": False, + "unique_indices": False, + }, + ) + + result_tensor = ctx.build_op("reshape", [scattered], x_shape, x_dtype) + return NKIPyTensorRef(result_tensor) + + +def static_slice( + x, + start_indices: List[int], + limit_indices: List[int], + strides: List[int], + squeeze_dims: List[int], +): + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + backend_tensor = x.backend_tensor + dtype = x.dtype + else: + backend_tensor = x + dtype = x.dtype + + slice_shape = [] + for start, limit, stride in zip(start_indices, limit_indices, strides): + size = (limit - start + stride - 1) // stride + slice_shape.append(size) + + result_tensor = ctx.build_op( + "slice", + [backend_tensor], + tuple(slice_shape), + dtype, + { + "start_indices": start_indices, + "limit_indices": limit_indices, + "strides": strides, + }, + ) + + if squeeze_dims: + output_shape = [s for i, s in enumerate(slice_shape) if i not in squeeze_dims] + final_shape = tuple(output_shape) if output_shape else () + result_tensor = ctx.build_op("reshape", [result_tensor], final_shape, dtype) + + return NKIPyTensorRef(result_tensor) + + +def dynamic_update_slice( + x, + value, + start_indices: List[int], + update_shape: Tuple[int, ...], +): + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x_tensor = x.backend_tensor + x_shape = x.shape + x_dtype = x.dtype + else: + x_tensor = x + x_shape = x.shape + x_dtype = x.dtype + + if isinstance(value, NKIPyTensorRef): + value_tensor = value.backend_tensor + elif isinstance(value, (int, float)): + value_array = np.full(update_shape, value, dtype=x_dtype) + value_tensor = ctx.build_op( + "constant", [], tuple(update_shape), x_dtype, {"value": value_array} + ) + elif isinstance(value, np.ndarray): + value_tensor = ctx.build_op( + "constant", [], value.shape, value.dtype, {"value": value} + ) + else: + value_tensor = value + + if value_tensor.shape != update_shape: + value_tensor = ctx.build_op( + "reshape", [value_tensor], update_shape, value_tensor.dtype + ) + + start_index_tensors = [] + for start_idx in start_indices: + scalar_tensor = ctx.build_op("constant", [], (), np.int32, {"value": start_idx}) + start_index_tensors.append(scalar_tensor) + + result_tensor = ctx.build_op( + "dynamic-update-slice", + [x_tensor, value_tensor] + start_index_tensors, + x_shape, + x_dtype, + {}, + ) + + return NKIPyTensorRef(result_tensor) + + +def scatter_strided( + x, + value, + scatter_indices_per_dim: List[List[int]], +): + ctx = get_hlo_context() + + if isinstance(x, NKIPyTensorRef): + x_tensor = x.backend_tensor + x_shape = x.shape + x_dtype = x.dtype + else: + x_tensor = x + x_shape = x.shape + x_dtype = x.dtype + + value_shape = tuple(len(indices) for indices in scatter_indices_per_dim) + + if isinstance(value, NKIPyTensorRef): + value_tensor = value.backend_tensor + elif isinstance(value, (int, float)): + value_array = np.full(value_shape, value, dtype=x_dtype) + value_tensor = ctx.build_op( + "constant", [], value_shape, x_dtype, {"value": value_array} + ) + elif isinstance(value, np.ndarray): + value_tensor = ctx.build_op( + "constant", [], value.shape, value.dtype, {"value": value} + ) + else: + value_tensor = value + + all_positions = list(itertools.product(*scatter_indices_per_dim)) + scatter_indices_array = np.array(all_positions, dtype=np.int32) + + indices_tensor = ctx.build_op( + "constant", + [], + scatter_indices_array.shape, + np.dtype(np.int32), + {"value": scatter_indices_array}, + ) + + flat_value_shape = (scatter_indices_array.shape[0],) + flat_value = ctx.build_op( + "reshape", [value_tensor], flat_value_shape, value_tensor.dtype + ) + + update_window_dims = [] + inserted_window_dims = list(range(len(x_shape))) + scatter_dims_to_operand_dims = list(range(len(x_shape))) + index_vector_dim = 1 + + result_tensor = ctx.build_op( + "scatter", + [x_tensor, indices_tensor, flat_value], + x_shape, + x_dtype, + { + "update_window_dims": update_window_dims, + "inserted_window_dims": inserted_window_dims, + "scatter_dims_to_operand_dims": scatter_dims_to_operand_dims, + "index_vector_dim": index_vector_dim, + "update_computation": "assign", + "indices_are_sorted": False, + "unique_indices": True, + }, + ) + + return NKIPyTensorRef(result_tensor) + + +# ============================================================================= +# NN ops +# ============================================================================= + + +def topk(x, k, axis=0, is_ascend=False, out=None, dtype=None): + ctx = get_hlo_context() + + if axis != -1 and axis != x.ndim - 1: + raise NotImplementedError("the custom TopK op only supports last axis") + + if isinstance(x, NKIPyTensorRef): + x = x.backend_tensor + + if axis < 0: + axis = len(x.shape) + axis + + assert x.shape[axis] >= k, ( + f"k={k} must be <= size of axis {axis} which is {x.shape[axis]}" + ) + + output_shape = list(x.shape) + output_shape[axis] = k + output_shape = tuple(output_shape) + + input_for_topk = x + if is_ascend: + input_for_topk = ctx.build_op( + "negate", [input_for_topk], input_for_topk.shape, input_for_topk.dtype + ) + + topk_output_shape = list(input_for_topk.shape) + topk_output_shape[-1] = k + topk_output_shape = tuple(topk_output_shape) + + topk_tuple = ctx.build_op( + "topk", + [input_for_topk], + topk_output_shape, + x.dtype, + {"k": k, "largest": True, "is_tuple": True}, + ) + + values_tensor = ctx.build_op( + "get-tuple-element", + [topk_tuple], + topk_output_shape, + x.dtype, + {"tuple_index": 0}, + ) + + indices_tensor = ctx.build_op( + "get-tuple-element", + [topk_tuple], + topk_output_shape, + np.dtype(np.uint32), + {"tuple_index": 1}, + ) + + if is_ascend: + values_tensor = ctx.build_op( + "negate", [values_tensor], topk_output_shape, x.dtype + ) + + return NKIPyTensorRef(values_tensor), NKIPyTensorRef(indices_tensor) + + +# ============================================================================= +# Collective ops +# ============================================================================= + + +def all_gather(data, all_gather_dim, replica_groups, **kwargs): + ctx = get_hlo_context() + + rank = len(replica_groups[0]) + out_shape = list(data.shape) + if out_shape: + out_shape[all_gather_dim] *= rank + + result_tensor = ctx.build_op( + "all-gather", + [data.backend_tensor], + tuple(out_shape), + data.dtype, + { + "all_gather_dim": all_gather_dim, + "replica_groups": replica_groups, + }, + ) + return NKIPyTensorRef(result_tensor) + + +def all_reduce(data, replica_groups, reduce_op=np.add, **kwargs): + ctx = get_hlo_context() + + reduce_op_map = { + np.add: "add", + np.multiply: "multiply", + np.maximum: "maximum", + np.minimum: "minimum", + } + reduce_op_str = reduce_op_map.get(reduce_op, "add") + + result_tensor = ctx.build_op( + "all-reduce", + [data.backend_tensor], + data.shape, + data.dtype, + { + "replica_groups": replica_groups, + "reduce_op": reduce_op_str, + }, + ) + return NKIPyTensorRef(result_tensor) + + +def reduce_scatter(data, reduce_scatter_dim: int, replica_groups, reduce_op=np.add, **kwargs): + ctx = get_hlo_context() + rank = len(replica_groups[0]) + out_shape = list(data.shape) + if out_shape: + out_shape[reduce_scatter_dim] //= rank + + reduce_op_map = { + np.add: "add", + np.multiply: "multiply", + np.maximum: "maximum", + np.minimum: "minimum", + } + reduce_op_str = reduce_op_map.get(reduce_op, "add") + + result_tensor = ctx.build_op( + "reduce-scatter", + [data.backend_tensor], + tuple(out_shape), + data.dtype, + { + "reduce_scatter_dim": reduce_scatter_dim, + "replica_groups": replica_groups, + "reduce_op": reduce_op_str, + }, + ) + return NKIPyTensorRef(result_tensor) + + +def all_to_all(data, split_dimension: int, concat_dimension: int, replica_groups, **kwargs): + ctx = get_hlo_context() + result_tensor = ctx.build_op( + "all-to-all", + [data.backend_tensor], + data.shape, + data.dtype, + { + "split_dimension": split_dimension, + "concat_dimension": concat_dimension, + "replica_groups": replica_groups, + }, + ) + return NKIPyTensorRef(result_tensor) + + +# ============================================================================= +# Convolution ops +# ============================================================================= + + +def conv2d( + input, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + out=None, + dtype=None, +): + from nkipy.core.ops.conv import _normalize_tuple_2d + + ctx = get_hlo_context() + + if isinstance(input, NKIPyTensorRef): + input = input.backend_tensor + if isinstance(weight, NKIPyTensorRef): + weight = weight.backend_tensor + + stride = _normalize_tuple_2d(stride, "stride") + dilation = _normalize_tuple_2d(dilation, "dilation") + padding_tuple = _normalize_tuple_2d(padding, "padding") + + batch_size, in_channels, in_height, in_width = input.shape + out_channels, _, kernel_height, kernel_width = weight.shape + + out_height = ( + in_height + 2 * padding_tuple[0] - dilation[0] * (kernel_height - 1) - 1 + ) // stride[0] + 1 + out_width = ( + in_width + 2 * padding_tuple[1] - dilation[1] * (kernel_width - 1) - 1 + ) // stride[1] + 1 + + output_shape = (batch_size, out_channels, out_height, out_width) + + padding_config = [ + (padding_tuple[0], padding_tuple[0]), + (padding_tuple[1], padding_tuple[1]), + ] + + result_tensor = ctx.build_op( + "convolution", + [input, weight], + output_shape, + input.dtype, + { + "window_strides": list(stride), + "padding": padding_config, + "lhs_dilation": [1, 1], + "rhs_dilation": list(dilation), + "feature_group_count": groups, + "batch_group_count": 1, + "input_batch_dimension": 0, + "input_feature_dimension": 1, + "input_spatial_dimensions": [2, 3], + "kernel_output_feature_dimension": 0, + "kernel_input_feature_dimension": 1, + "kernel_spatial_dimensions": [2, 3], + "output_batch_dimension": 0, + "output_feature_dimension": 1, + "output_spatial_dimensions": [2, 3], + }, + ) + + if bias is not None: + if isinstance(bias, NKIPyTensorRef): + bias = bias.backend_tensor + + bias_reshaped = ctx.build_op( + "reshape", [bias], (1, out_channels, 1, 1), bias.dtype + ) + bias_broadcast = broadcast_to_shape_hlo(ctx, bias_reshaped, output_shape) + result_tensor = ctx.build_op( + "add", [result_tensor, bias_broadcast], output_shape, input.dtype + ) + + return NKIPyTensorRef(result_tensor) + + +def conv3d( + input, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + out=None, + dtype=None, +): + from nkipy.core.ops.conv import _normalize_tuple_3d + + ctx = get_hlo_context() + + if isinstance(input, NKIPyTensorRef): + input = input.backend_tensor + if isinstance(weight, NKIPyTensorRef): + weight = weight.backend_tensor + + stride = _normalize_tuple_3d(stride, "stride") + dilation = _normalize_tuple_3d(dilation, "dilation") + padding_tuple = _normalize_tuple_3d(padding, "padding") + + batch_size, in_channels, in_depth, in_height, in_width = input.shape + out_channels, _, kernel_depth, kernel_height, kernel_width = weight.shape + + out_depth = ( + in_depth + 2 * padding_tuple[0] - dilation[0] * (kernel_depth - 1) - 1 + ) // stride[0] + 1 + out_height = ( + in_height + 2 * padding_tuple[1] - dilation[1] * (kernel_height - 1) - 1 + ) // stride[1] + 1 + out_width = ( + in_width + 2 * padding_tuple[2] - dilation[2] * (kernel_width - 1) - 1 + ) // stride[2] + 1 + + output_shape = (batch_size, out_channels, out_depth, out_height, out_width) + + padding_config = [ + (padding_tuple[0], padding_tuple[0]), + (padding_tuple[1], padding_tuple[1]), + (padding_tuple[2], padding_tuple[2]), + ] + + result_tensor = ctx.build_op( + "convolution", + [input, weight], + output_shape, + input.dtype, + { + "window_strides": list(stride), + "padding": padding_config, + "lhs_dilation": [1, 1, 1], + "rhs_dilation": list(dilation), + "feature_group_count": groups, + "batch_group_count": 1, + "input_batch_dimension": 0, + "input_feature_dimension": 1, + "input_spatial_dimensions": [2, 3, 4], + "kernel_output_feature_dimension": 0, + "kernel_input_feature_dimension": 1, + "kernel_spatial_dimensions": [2, 3, 4], + "output_batch_dimension": 0, + "output_feature_dimension": 1, + "output_spatial_dimensions": [2, 3, 4], + }, + ) + + if bias is not None: + if isinstance(bias, NKIPyTensorRef): + bias = bias.backend_tensor + + bias_reshaped = ctx.build_op( + "reshape", [bias], (1, out_channels, 1, 1, 1), bias.dtype + ) + bias_broadcast = broadcast_to_shape_hlo(ctx, bias_reshaped, output_shape) + result_tensor = ctx.build_op( + "add", [result_tensor, bias_broadcast], output_shape, input.dtype + ) + + return NKIPyTensorRef(result_tensor) diff --git a/nkipy/src/nkipy/core/ops/_nkigen_impls.py b/nkipy/src/nkipy/core/ops/_nkigen_impls.py new file mode 100644 index 0000000..3a91fe2 --- /dev/null +++ b/nkipy/src/nkipy/core/ops/_nkigen_impls.py @@ -0,0 +1,284 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""NkiGen backend implementations for NKIPy ops. + +Trivial ops use _unary/_binary factories that delegate to the builder. +Non-trivial ops with custom logic are explicit functions below. +""" + +from __future__ import annotations + +import numpy as np + +from nkipy.core.tensor import NKIPyTensorRef +from nkipy.core.backend.nkigen import NkiGenTensor + +_builder_module = None + + +def _builder(): + global _builder_module + if _builder_module is None: + import nkigen.builder as _mod + _builder_module = _mod + return _builder_module + + +def _unwrap(x): + if isinstance(x, NKIPyTensorRef): + return x.backend_tensor.handle + return x + + +def _wrap(handle): + kt = NkiGenTensor(handle, handle.shape, handle.dtype) + return NKIPyTensorRef(kt) + + +# --------------------------------------------------------------------------- +# Factories for trivial delegation to builder +# --------------------------------------------------------------------------- + +def _unary(method): + def impl(x, out=None, dtype=None): + return _wrap(getattr(_builder(), method)(_unwrap(x))) + return impl + + +def _binary(method): + def impl(x, y, out=None, dtype=None): + return _wrap(getattr(_builder(), method)(_unwrap(x), _unwrap(y))) + return impl + + +def _reduce(method): + def impl(x, axis=None, keepdims=False, **kwargs): + return _wrap(getattr(_builder(), method)(_unwrap(x), axis=axis, keepdims=keepdims)) + return impl + + +# Binary ops +add = _binary("add") +subtract = _binary("subtract") +multiply = _binary("multiply") +divide = _binary("divide") +power = _binary("power") +maximum = _binary("maximum") +minimum = _binary("minimum") +equal = _binary("equal") +not_equal = _binary("not_equal") +greater = _binary("greater") +greater_equal = _binary("greater_equal") +less = _binary("less") +less_equal = _binary("less_equal") +bitwise_and = _binary("bitwise_and") +bitwise_or = _binary("bitwise_or") +bitwise_xor = _binary("bitwise_xor") +matmul = _binary("matmul") + +# Unary ops +exp = _unary("exp") +log = _unary("log") +sqrt = _unary("sqrt") +tanh = _unary("tanh") +sin = _unary("sin") +cos = _unary("cos") +sign = _unary("sign") +abs = _unary("abs_") +ceil = _unary("ceil_") +floor = _unary("floor_") + +# Reductions +reduce_sum = _reduce("reduce_sum") +reduce_prod = _reduce("reduce_prod") +reduce_max = _reduce("reduce_max") +reduce_min = _reduce("reduce_min") +reduce_mean = _reduce("reduce_mean") +reduce_std = _reduce("reduce_std") +reduce_var = _reduce("reduce_var") + + +# --------------------------------------------------------------------------- +# Composed unary ops +# --------------------------------------------------------------------------- + +def negative(x, out=None, dtype=None): + return _wrap(_builder().subtract(_unwrap(0), _unwrap(x))) + + +def reciprocal(x, out=None, dtype=None): + return _wrap(_builder().divide(_unwrap(1.0), _unwrap(x))) + + +def square(x, out=None, dtype=None): + h = _unwrap(x) + return _wrap(_builder().multiply(h, h)) + + +def logical_not(x, out=None, dtype=None): + return _wrap(_builder().subtract(_unwrap(1), _unwrap(x))) + + +# --------------------------------------------------------------------------- +# Transform ops with custom signatures +# --------------------------------------------------------------------------- + +def transpose(x, axes=None): + return _wrap(_builder().transpose(_unwrap(x), axes=axes)) + + +def reshape(x, newshape, order='C'): + return _wrap(_builder().reshape(_unwrap(x), newshape)) + + +def expand_dims(x, axis): + return _wrap(_builder().expand_dims(_unwrap(x), axis)) + + +def copy(x, order='K', subok=True): + return _wrap(_builder().copy_(_unwrap(x))) + + +def broadcast_to(x, shape): + return _wrap(_builder().broadcast_to(_unwrap(x), tuple(shape))) + + +def astype(x, dtype): + return _wrap(_builder().astype(_unwrap(x), dtype)) + + +def concatenate(arrays, axis=0, out=None, dtype=None): + handles = [_unwrap(a) for a in arrays] + return _wrap(_builder().concatenate(handles, axis=axis)) + + +def where(condition, x, y): + return _wrap(_builder().where(_unwrap(condition), _unwrap(x), _unwrap(y))) + + +def take(a, indices, axis=0): + return _wrap(_builder().take(_unwrap(a), _unwrap(indices), axis=axis)) + + +# --------------------------------------------------------------------------- +# Creation ops +# --------------------------------------------------------------------------- + +def zeros(shape, dtype=np.float32): + return _wrap(_builder().zeros(tuple(shape), dtype)) + + +def full(shape, fill_value, dtype=np.float32): + return _wrap(_builder().full(tuple(shape), fill_value, dtype)) + + +def zeros_like(x, dtype=None): + h = _unwrap(x) + dt = dtype if dtype is not None else h.dtype + return zeros(h.shape, dt) + + +def ones_like(x, dtype=None): + h = _unwrap(x) + dt = dtype if dtype is not None else h.dtype + return full(h.shape, 1.0, dt) + + +def empty_like(x, dtype=None): + h = _unwrap(x) + dt = dtype if dtype is not None else h.dtype + return _wrap(_builder().empty(h.shape, dt)) + + +def full_like(x, fill_value, dtype=None): + h = _unwrap(x) + dt = dtype if dtype is not None else h.dtype + return full(h.shape, fill_value, dt) + + +# --------------------------------------------------------------------------- +# Squeeze / swapaxes / stack / split +# --------------------------------------------------------------------------- + +def squeeze(x, axis=None): + h = _unwrap(x) + shape = h.shape + if axis is None: + new_shape = tuple(d for d in shape if d != 1) + else: + if isinstance(axis, int): + axis = (axis,) + new_shape = tuple(d for i, d in enumerate(shape) if i not in axis) + if new_shape == shape: + return x + return reshape(x, new_shape) + + +def swapaxes(x, axis1, axis2): + h = _unwrap(x) + rank = len(h.shape) + perm = list(range(rank)) + perm[axis1], perm[axis2] = perm[axis2], perm[axis1] + return transpose(x, axes=perm) + + +def stack(arrays, axis=0): + expanded = [expand_dims(a, axis) for a in arrays] + return concatenate(expanded, axis=axis) + + +def split(x, indices_or_sections, axis=0): + h = _unwrap(x) + shape = h.shape + if isinstance(indices_or_sections, int): + sections = indices_or_sections + size = shape[axis] + section_size = size // sections + results = [] + for i in range(sections): + start = [0] * len(shape) + start[axis] = i * section_size + limit = list(shape) + limit[axis] = (i + 1) * section_size + strides = [1] * len(shape) + results.append(static_slice(x, start, limit, strides, [])) + return tuple(results) + raise NotImplementedError("split with explicit indices not yet implemented") + + +# --------------------------------------------------------------------------- +# Static slicing +# --------------------------------------------------------------------------- + +def static_slice(x, start_indices, limit_indices, strides, squeeze_dims): + return _wrap(_builder().static_slice( + _unwrap(x), start_indices, limit_indices, strides, squeeze_dims, + )) + + +# --------------------------------------------------------------------------- +# Slice assignment (dynamic_update_slice) +# --------------------------------------------------------------------------- + +def dynamic_update_slice(x, value, start_indices, update_shape): + b = _builder() + x_h = _unwrap(x) + if isinstance(value, NKIPyTensorRef): + value_h = _unwrap(value) + elif isinstance(value, (int, float)): + value_h = b.full(tuple(update_shape), value, x_h.dtype) + elif isinstance(value, np.ndarray): + raise NotImplementedError( + "Assigning a raw np.ndarray constant in nkigen is not supported. " + "Use a traced tensor expression instead." + ) + else: + value_h = value + + if value_h.shape != tuple(update_shape): + value_h = b.reshape(value_h, tuple(update_shape)) + + sizes = list(update_shape) + strides = [1] * len(start_indices) + result_h = b.static_insert_slice(x_h, value_h, start_indices, sizes, strides) + return _wrap(result_h) diff --git a/nkipy/src/nkipy/core/ops/_register_hlo.py b/nkipy/src/nkipy/core/ops/_register_hlo.py new file mode 100644 index 0000000..c15adf4 --- /dev/null +++ b/nkipy/src/nkipy/core/ops/_register_hlo.py @@ -0,0 +1,160 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Register HLO backend implementations for all ops. + +Called lazily the first time the HLO backend is activated, so HLO-specific +imports only happen when needed. + +Composed ops (floor_divide, tan, rint, etc.) are registered as +``composed_impl`` on the Op itself — they need no per-backend registration +since they dispatch through other ops. +""" + +_registered = False + + +def register_all_hlo_impls(): + global _registered + if _registered: + return + _registered = True + + from nkipy.core.ops import _hlo_impls as hlo_impls + + # --- Binary ops (primitives) --- + from nkipy.core.ops.binary import ( + add, subtract, multiply, divide, power, maximum, minimum, + bitwise_and, bitwise_or, bitwise_xor, + equal, not_equal, greater, greater_equal, less, less_equal, + logical_and, logical_or, logical_xor, + ) + add.impl("hlo")(hlo_impls.add) + subtract.impl("hlo")(hlo_impls.subtract) + multiply.impl("hlo")(hlo_impls.multiply) + divide.impl("hlo")(hlo_impls.divide) + power.impl("hlo")(hlo_impls.power) + maximum.impl("hlo")(hlo_impls.maximum) + minimum.impl("hlo")(hlo_impls.minimum) + bitwise_and.impl("hlo")(hlo_impls.bitwise_and) + bitwise_or.impl("hlo")(hlo_impls.bitwise_or) + bitwise_xor.impl("hlo")(hlo_impls.bitwise_xor) + equal.impl("hlo")(hlo_impls.equal) + not_equal.impl("hlo")(hlo_impls.not_equal) + greater.impl("hlo")(hlo_impls.greater) + greater_equal.impl("hlo")(hlo_impls.greater_equal) + less.impl("hlo")(hlo_impls.less) + less_equal.impl("hlo")(hlo_impls.less_equal) + logical_and.impl("hlo")(hlo_impls.logical_and) + logical_or.impl("hlo")(hlo_impls.logical_or) + logical_xor.impl("hlo")(hlo_impls.logical_xor) + + # --- Unary ops (primitives) --- + # Note: reciprocal, square, logical_not use composed_impl (dispatch through + # other ops) so they don't need backend-specific registration. + from nkipy.core.ops.unary import ( + abs, exp, log, sqrt, sin, cos, tanh, ceil, floor, sign, + negative, arctan, invert, bitwise_not, + ) + abs.impl("hlo")(hlo_impls.abs) + exp.impl("hlo")(hlo_impls.exp) + log.impl("hlo")(hlo_impls.log) + sqrt.impl("hlo")(hlo_impls.sqrt) + sin.impl("hlo")(hlo_impls.sin) + cos.impl("hlo")(hlo_impls.cos) + tanh.impl("hlo")(hlo_impls.tanh) + ceil.impl("hlo")(hlo_impls.ceil) + floor.impl("hlo")(hlo_impls.floor) + sign.impl("hlo")(hlo_impls.sign) + negative.impl("hlo")(hlo_impls.negative) + arctan.impl("hlo")(hlo_impls.arctan) + invert.impl("hlo")(hlo_impls.invert) + bitwise_not.impl("hlo")(hlo_impls.bitwise_not) + + # --- Reduction ops (primitives) --- + from nkipy.core.ops.reduce import sum, prod, max, min, argmax, argmin + sum.impl("hlo")(hlo_impls.reduce_sum) + prod.impl("hlo")(hlo_impls.reduce_prod) + max.impl("hlo")(hlo_impls.reduce_max) + min.impl("hlo")(hlo_impls.reduce_min) + argmax.impl("hlo")(hlo_impls.argmax) + argmin.impl("hlo")(hlo_impls.argmin) + + # --- Linalg ops --- + from nkipy.core.ops.linalg import matmul, dot, trace + matmul.impl("hlo")(hlo_impls.matmul) + dot.impl("hlo")(hlo_impls.dot) + trace.impl("hlo")(hlo_impls.trace) + + # --- Creation ops --- + from nkipy.core.ops.creation import ( + zeros as zeros_op, full as full_op, constant, + zeros_like, ones_like, empty_like, full_like, + tril, triu, diag, + ) + zeros_op.impl("hlo")(hlo_impls.zeros) + full_op.impl("hlo")(hlo_impls.full) + constant.impl("hlo")(hlo_impls.constant) + zeros_like.impl("hlo")(hlo_impls.zeros_like) + ones_like.impl("hlo")(hlo_impls.ones_like) + empty_like.impl("hlo")(hlo_impls.empty_like) + full_like.impl("hlo")(hlo_impls.full_like) + tril.impl("hlo")(hlo_impls.tril) + triu.impl("hlo")(hlo_impls.triu) + diag.impl("hlo")(hlo_impls.diag) + + # --- Transform ops --- + from nkipy.core.ops.transform import ( + reshape, transpose, expand_dims, concatenate, split, + copy, repeat, broadcast_to, astype, squeeze, pad, + swapaxes, stack, diff, flip, tile, roll, + ) + reshape.impl("hlo")(hlo_impls.reshape) + transpose.impl("hlo")(hlo_impls.transpose) + expand_dims.impl("hlo")(hlo_impls.expand_dims) + concatenate.impl("hlo")(hlo_impls.concatenate) + split.impl("hlo")(hlo_impls.split) + copy.impl("hlo")(hlo_impls.copy) + repeat.impl("hlo")(hlo_impls.repeat) + broadcast_to.impl("hlo")(hlo_impls.broadcast_to) + astype.impl("hlo")(hlo_impls.astype) + squeeze.impl("hlo")(hlo_impls.squeeze) + pad.impl("hlo")(hlo_impls.pad) + swapaxes.impl("hlo")(hlo_impls.swapaxes) + stack.impl("hlo")(hlo_impls.stack) + diff.impl("hlo")(hlo_impls.diff) + flip.impl("hlo")(hlo_impls.flip) + tile.impl("hlo")(hlo_impls.tile) + roll.impl("hlo")(hlo_impls.roll) + + # --- Indexing ops --- + from nkipy.core.ops.indexing import ( + where as where_op, take as take_op, + take_along_axis, put_along_axis, scatter_along_axis, + static_slice, dynamic_update_slice, scatter_strided, + ) + where_op.impl("hlo")(hlo_impls.where) + take_op.impl("hlo")(hlo_impls.take) + take_along_axis.impl("hlo")(hlo_impls.take_along_axis) + put_along_axis.impl("hlo")(hlo_impls.put_along_axis) + scatter_along_axis.impl("hlo")(hlo_impls.scatter_along_axis) + static_slice.impl("hlo")(hlo_impls.static_slice) + dynamic_update_slice.impl("hlo")(hlo_impls.dynamic_update_slice) + scatter_strided.impl("hlo")(hlo_impls.scatter_strided) + + # --- NN ops --- + from nkipy.core.ops.nn import topk + topk.impl("hlo")(hlo_impls.topk) + + # --- Collective ops --- + from nkipy.core.ops.collectives import ( + all_gather, all_reduce, reduce_scatter, all_to_all, + ) + all_gather.impl("hlo")(hlo_impls.all_gather) + all_reduce.impl("hlo")(hlo_impls.all_reduce) + reduce_scatter.impl("hlo")(hlo_impls.reduce_scatter) + all_to_all.impl("hlo")(hlo_impls.all_to_all) + + # --- Conv ops --- + from nkipy.core.ops.conv import conv2d, conv3d + conv2d.impl("hlo")(hlo_impls.conv2d) + conv3d.impl("hlo")(hlo_impls.conv3d) diff --git a/nkipy/src/nkipy/core/ops/_register_nkigen.py b/nkipy/src/nkipy/core/ops/_register_nkigen.py new file mode 100644 index 0000000..821f5a9 --- /dev/null +++ b/nkipy/src/nkipy/core/ops/_register_nkigen.py @@ -0,0 +1,118 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Register nkigen backend implementations for all ops. + +Called lazily the first time the nkigen backend is activated, so MLIR +imports only happen when needed. + +Composed ops (floor_divide, tan, rint, etc.) use ``composed_impl`` on the +Op itself and need no per-backend registration — they dispatch through +other ops that have nkigen primitives registered. +""" + +_registered = False + + +def register_all_nkigen_impls(): + global _registered + if _registered: + return + _registered = True + + from nkipy.core.ops import _nkigen_impls as nkigen_impls + + # --- Binary ops (primitives) --- + from nkipy.core.ops.binary import ( + add, subtract, multiply, divide, power, maximum, minimum, + equal, not_equal, greater, greater_equal, less, less_equal, + bitwise_and, bitwise_or, bitwise_xor, + ) + add.impl("nkigen")(nkigen_impls.add) + subtract.impl("nkigen")(nkigen_impls.subtract) + multiply.impl("nkigen")(nkigen_impls.multiply) + divide.impl("nkigen")(nkigen_impls.divide) + power.impl("nkigen")(nkigen_impls.power) + maximum.impl("nkigen")(nkigen_impls.maximum) + minimum.impl("nkigen")(nkigen_impls.minimum) + equal.impl("nkigen")(nkigen_impls.equal) + not_equal.impl("nkigen")(nkigen_impls.not_equal) + greater.impl("nkigen")(nkigen_impls.greater) + greater_equal.impl("nkigen")(nkigen_impls.greater_equal) + less.impl("nkigen")(nkigen_impls.less) + less_equal.impl("nkigen")(nkigen_impls.less_equal) + bitwise_and.impl("nkigen")(nkigen_impls.bitwise_and) + bitwise_or.impl("nkigen")(nkigen_impls.bitwise_or) + bitwise_xor.impl("nkigen")(nkigen_impls.bitwise_xor) + + # --- Unary ops (primitives) --- + from nkipy.core.ops.unary import ( + abs, exp, log, sqrt, sin, cos, tanh, ceil, floor, sign, + negative, reciprocal, square, logical_not, + ) + exp.impl("nkigen")(nkigen_impls.exp) + log.impl("nkigen")(nkigen_impls.log) + sqrt.impl("nkigen")(nkigen_impls.sqrt) + tanh.impl("nkigen")(nkigen_impls.tanh) + sin.impl("nkigen")(nkigen_impls.sin) + cos.impl("nkigen")(nkigen_impls.cos) + sign.impl("nkigen")(nkigen_impls.sign) + abs.impl("nkigen")(nkigen_impls.abs) + ceil.impl("nkigen")(nkigen_impls.ceil) + floor.impl("nkigen")(nkigen_impls.floor) + negative.impl("nkigen")(nkigen_impls.negative) + reciprocal.impl("nkigen")(nkigen_impls.reciprocal) + square.impl("nkigen")(nkigen_impls.square) + logical_not.impl("nkigen")(nkigen_impls.logical_not) + + # --- Linalg ops --- + from nkipy.core.ops.linalg import matmul + matmul.impl("nkigen")(nkigen_impls.matmul) + + # --- Reduction ops --- + from nkipy.core.ops.reduce import sum, prod, max, min, mean, std, var + sum.impl("nkigen")(nkigen_impls.reduce_sum) + prod.impl("nkigen")(nkigen_impls.reduce_prod) + max.impl("nkigen")(nkigen_impls.reduce_max) + min.impl("nkigen")(nkigen_impls.reduce_min) + mean.impl("nkigen")(nkigen_impls.reduce_mean) + std.impl("nkigen")(nkigen_impls.reduce_std) + var.impl("nkigen")(nkigen_impls.reduce_var) + + # --- Creation ops --- + from nkipy.core.ops.creation import ( + zeros as zeros_op, full as full_op, + zeros_like, ones_like, empty_like, full_like, + ) + zeros_op.impl("nkigen")(nkigen_impls.zeros) + full_op.impl("nkigen")(nkigen_impls.full) + zeros_like.impl("nkigen")(nkigen_impls.zeros_like) + ones_like.impl("nkigen")(nkigen_impls.ones_like) + empty_like.impl("nkigen")(nkigen_impls.empty_like) + full_like.impl("nkigen")(nkigen_impls.full_like) + + # --- Transform ops --- + from nkipy.core.ops.transform import ( + transpose, reshape, expand_dims, concatenate, + split, copy, broadcast_to, astype, squeeze, swapaxes, stack, + ) + transpose.impl("nkigen")(nkigen_impls.transpose) + reshape.impl("nkigen")(nkigen_impls.reshape) + expand_dims.impl("nkigen")(nkigen_impls.expand_dims) + concatenate.impl("nkigen")(nkigen_impls.concatenate) + split.impl("nkigen")(nkigen_impls.split) + copy.impl("nkigen")(nkigen_impls.copy) + broadcast_to.impl("nkigen")(nkigen_impls.broadcast_to) + astype.impl("nkigen")(nkigen_impls.astype) + squeeze.impl("nkigen")(nkigen_impls.squeeze) + swapaxes.impl("nkigen")(nkigen_impls.swapaxes) + stack.impl("nkigen")(nkigen_impls.stack) + + # --- Indexing ops --- + from nkipy.core.ops.indexing import ( + where as where_op, take as take_op, + static_slice, dynamic_update_slice, + ) + where_op.impl("nkigen")(nkigen_impls.where) + take_op.impl("nkigen")(nkigen_impls.take) + static_slice.impl("nkigen")(nkigen_impls.static_slice) + dynamic_update_slice.impl("nkigen")(nkigen_impls.dynamic_update_slice) diff --git a/nkipy/src/nkipy/core/ops/_registry.py b/nkipy/src/nkipy/core/ops/_registry.py index 4c0a2c2..653d284 100644 --- a/nkipy/src/nkipy/core/ops/_registry.py +++ b/nkipy/src/nkipy/core/ops/_registry.py @@ -6,7 +6,7 @@ appropriate backend implementation based on the current tracing state. """ -from typing import Callable, Dict +from typing import Callable, Dict, Optional class Op: @@ -28,11 +28,24 @@ def _zeros_cpu(shape, dtype): # Later, during tracing: result = zeros(shape, dtype) # Dispatches based on current backend + + Composed ops (built from other dispatched ops) use ``composed_impl``:: + + floor_divide = Op('floor_divide') + + @floor_divide.composed_impl + def _floor_divide(x, y): + return floor(divide(x, y)) + + The composed fallback is used when no backend-specific implementation + is registered. Since it calls other ops via dispatch, it works on any + backend that has the underlying primitives registered. """ def __init__(self, name: str): self.name = name self._impls: Dict[str, Callable] = {} + self._composed: Optional[Callable] = None def impl(self, backend: str) -> Callable: """Decorator to register a backend implementation. @@ -50,19 +63,37 @@ def decorator(fn: Callable) -> Callable: return decorator + @property + def composed_impl(self): + """Decorator to register a composed (backend-agnostic) fallback. + + A composed implementation is built entirely from calls to other + dispatched ops, so it works on any backend that has those primitives. + It is used as a fallback when no backend-specific impl is registered. + """ + + def decorator(fn: Callable) -> Callable: + self._composed = fn + return fn + + return decorator + def __call__(self, *args, **kwargs): """Dispatch to the appropriate backend implementation.""" from nkipy.core.backend import get_backend backend = get_backend() - if backend not in self._impls: - available = ", ".join(self._impls.keys()) if self._impls else "none" - raise NotImplementedError( - f"Operation '{self.name}' not implemented for backend '{backend}'." - f" Available backends: {available}" - ) - return self._impls[backend](*args, **kwargs) + if backend in self._impls: + return self._impls[backend](*args, **kwargs) + if self._composed is not None: + return self._composed(*args, **kwargs) + available = ", ".join(self._impls.keys()) if self._impls else "none" + raise NotImplementedError( + f"Operation '{self.name}' not implemented for backend '{backend}'." + f" Available backends: {available}" + ) def __repr__(self) -> str: backends = ", ".join(self._impls.keys()) if self._impls else "none" - return f"Op('{self.name}', backends=[{backends}])" + composed = ", composed" if self._composed else "" + return f"Op('{self.name}', backends=[{backends}]{composed})" diff --git a/nkipy/src/nkipy/core/ops/binary.py b/nkipy/src/nkipy/core/ops/binary.py index e486e47..55d75f6 100644 --- a/nkipy/src/nkipy/core/ops/binary.py +++ b/nkipy/src/nkipy/core/ops/binary.py @@ -2,330 +2,74 @@ # SPDX-License-Identifier: Apache-2.0 """Binary operations: add, subtract, multiply, divide, etc.""" -import numpy as np - -from nkipy.core.backend.hlo import ( - as_hlo_tensor, - broadcast_operands_hlo, - find_common_type_hlo, - get_hlo_context, -) from nkipy.core.ops._registry import Op -from nkipy.core.tensor import NKIPyTensorRef - -# ============================================================================= -# HLO Implementation Helpers -# ============================================================================= - - -def _build_binary_hlo(x, y, np_op, out=None, dtype=None): - """Build a binary HLO operation with broadcasting - - Args: - x: First operand - y: Second operand - np_op: NumPy operation - out: Unused (for API compatibility with IR) - dtype: Optional output dtype to cast result to - """ - ctx = get_hlo_context() - - promoted_dtype = find_common_type_hlo(x, y) - - # If dtype parameter is provided, use it as the output dtype - if dtype is not None: - output_dtype = np.dtype(dtype) - else: - output_dtype = promoted_dtype - - # Convert to HLOTensor - x = ( - x.backend_tensor - if isinstance(x, NKIPyTensorRef) - else as_hlo_tensor(ctx, x, promoted_dtype) - ) - y = ( - y.backend_tensor - if isinstance(y, NKIPyTensorRef) - else as_hlo_tensor(ctx, y, promoted_dtype) - ) - - # Map numpy ops to HLO opcodes - op_map = { - np.add: "add", - np.subtract: "subtract", - np.multiply: "multiply", - np.divide: "divide", - np.power: "power", - np.maximum: "maximum", - np.minimum: "minimum", - np.bitwise_and: "and", - np.bitwise_or: "or", - np.bitwise_xor: "xor", - np.logical_and: "and", - np.logical_or: "or", - np.logical_xor: "xor", - } - - hlo_op = op_map.get( - np_op, np_op.__name__ if hasattr(np_op, "__name__") else str(np_op) - ) - - # Broadcast operands to compatible shapes - x_broadcast, y_broadcast = broadcast_operands_hlo(ctx, x, y) - - # Explicit type promotion to match IR behavior - # This ensures operands have the correct dtype before the operation - if output_dtype != x_broadcast.dtype: - x_broadcast = ctx.build_op( - "convert", [x_broadcast], x_broadcast.shape, output_dtype - ) - if output_dtype != y_broadcast.dtype: - y_broadcast = ctx.build_op( - "convert", [y_broadcast], y_broadcast.shape, output_dtype - ) - - result_tensor = ctx.build_op( - hlo_op, [x_broadcast, y_broadcast], x_broadcast.shape, output_dtype - ) - - return NKIPyTensorRef(result_tensor) - - -def _build_comparison_hlo(x, y, np_op, out=None, dtype=None): - """Build a comparison HLO operation""" - ctx = get_hlo_context() - - promoted_dtype = find_common_type_hlo(x, y) - - # Convert to HLOTensor - x = ( - x.backend_tensor - if isinstance(x, NKIPyTensorRef) - else as_hlo_tensor(ctx, x, promoted_dtype) - ) - y = ( - y.backend_tensor - if isinstance(y, NKIPyTensorRef) - else as_hlo_tensor(ctx, y, promoted_dtype) - ) - - # Type promotion: convert both tensors to the promoted dtype if needed - if x.dtype != promoted_dtype: - x = ctx.build_op("convert", [x], x.shape, promoted_dtype) - if y.dtype != promoted_dtype: - y = ctx.build_op("convert", [y], y.shape, promoted_dtype) - - # Map numpy ops to HLO comparison directions - comp_map = { - np.equal: "EQ", - np.not_equal: "NE", - np.less: "LT", - np.less_equal: "LE", - np.greater: "GT", - np.greater_equal: "GE", - } - - comp_dir = comp_map.get(np_op, "EQ") - - # Broadcast operands to compatible shapes - x_broadcast, y_broadcast = broadcast_operands_hlo(ctx, x, y) - - result_tensor = ctx.build_op( - "compare", - [x_broadcast, y_broadcast], - x_broadcast.shape, - np.bool_, - {"comparison_direction": comp_dir}, - ) - - return NKIPyTensorRef(result_tensor) - - -# ----------------------------------------------------------------------------- -# Factory function for simple binary ops -# ----------------------------------------------------------------------------- -def _make_binary_op(name: str, np_op) -> Op: - """Create a binary Op with IR and HLO implementations.""" - op = Op(name) - - @op.impl("hlo") - def _impl_hlo(x, y, out=None, dtype=None): - return _build_binary_hlo(x, y, np_op, out=out, dtype=dtype) - - return op - - -def _make_comparison_op(name: str, np_op) -> Op: - """Create a comparison Op with IR and HLO implementations.""" - op = Op(name) - - @op.impl("hlo") - def _impl_hlo(x, y, out=None, dtype=None): - return _build_comparison_hlo(x, y, np_op, out=out, dtype=dtype) - - return op - - -def _build_logical_hlo(x, y, hlo_op_name, out=None, dtype=None): - """Build a logical HLO operation. - - Logical operations first convert inputs to boolean (non-zero check), - then apply the logical operation. - - Args: - x: First operand - y: Second operand - hlo_op_name: HLO operation name ('and', 'or', 'xor') - out: Unused (for API compatibility) - dtype: Unused (for API compatibility) - """ - ctx = get_hlo_context() - - # Find common type for scalars - promoted_dtype = find_common_type_hlo(x, y) - - # Convert to HLOTensor - x = ( - x.backend_tensor - if isinstance(x, NKIPyTensorRef) - else as_hlo_tensor(ctx, x, promoted_dtype) - ) - y = ( - y.backend_tensor - if isinstance(y, NKIPyTensorRef) - else as_hlo_tensor(ctx, y, promoted_dtype) - ) - - # Broadcast operands to compatible shapes - x_broadcast, y_broadcast = broadcast_operands_hlo(ctx, x, y) - - # Convert to boolean: x != 0 - zero_x = as_hlo_tensor(ctx, 0, x_broadcast.dtype) - if x_broadcast.shape: - zero_x = ctx.build_op( - "broadcast", - [zero_x], - x_broadcast.shape, - x_broadcast.dtype, - {"broadcast_dimensions": []}, - ) - x_bool = ctx.build_op( - "compare", - [x_broadcast, zero_x], - x_broadcast.shape, - np.bool_, - {"comparison_direction": "NE"}, - ) - - zero_y = as_hlo_tensor(ctx, 0, y_broadcast.dtype) - if y_broadcast.shape: - zero_y = ctx.build_op( - "broadcast", - [zero_y], - y_broadcast.shape, - y_broadcast.dtype, - {"broadcast_dimensions": []}, - ) - y_bool = ctx.build_op( - "compare", - [y_broadcast, zero_y], - y_broadcast.shape, - np.bool_, - {"comparison_direction": "NE"}, - ) - - # Apply logical operation on boolean values - result_tensor = ctx.build_op( - hlo_op_name, [x_bool, y_bool], x_broadcast.shape, np.bool_ - ) - - return NKIPyTensorRef(result_tensor) - - -def _make_logical_op(name: str, hlo_op_name: str) -> Op: - """Create a logical Op with HLO implementation. - - Logical operations convert inputs to boolean first, then apply the operation. - """ - op = Op(name) - - @op.impl("hlo") - def _impl_hlo(x, y, out=None, dtype=None): - return _build_logical_hlo(x, y, hlo_op_name, out=out, dtype=dtype) - - return op - # ----------------------------------------------------------------------------- # Arithmetic operations # ----------------------------------------------------------------------------- -add = _make_binary_op("add", np.add) -subtract = _make_binary_op("subtract", np.subtract) -multiply = _make_binary_op("multiply", np.multiply) -divide = _make_binary_op("divide", np.divide) -power = _make_binary_op("power", np.power) -maximum = _make_binary_op("maximum", np.maximum) -minimum = _make_binary_op("minimum", np.minimum) +add = Op("add") +subtract = Op("subtract") +multiply = Op("multiply") +divide = Op("divide") +power = Op("power") +maximum = Op("maximum") +minimum = Op("minimum") # ----------------------------------------------------------------------------- # Bitwise operations # ----------------------------------------------------------------------------- -bitwise_and = _make_binary_op("bitwise_and", np.bitwise_and) -bitwise_or = _make_binary_op("bitwise_or", np.bitwise_or) -bitwise_xor = _make_binary_op("bitwise_xor", np.bitwise_xor) +bitwise_and = Op("bitwise_and") +bitwise_or = Op("bitwise_or") +bitwise_xor = Op("bitwise_xor") # ----------------------------------------------------------------------------- # Comparison operations # ----------------------------------------------------------------------------- -equal = _make_comparison_op("equal", np.equal) -not_equal = _make_comparison_op("not_equal", np.not_equal) -greater = _make_comparison_op("greater", np.greater) -greater_equal = _make_comparison_op("greater_equal", np.greater_equal) -less = _make_comparison_op("less", np.less) -less_equal = _make_comparison_op("less_equal", np.less_equal) +equal = Op("equal") +not_equal = Op("not_equal") +greater = Op("greater") +greater_equal = Op("greater_equal") +less = Op("less") +less_equal = Op("less_equal") # ----------------------------------------------------------------------------- # Logical operations # ----------------------------------------------------------------------------- -logical_and = _make_logical_op("logical_and", "and") -logical_or = _make_logical_op("logical_or", "or") -logical_xor = _make_logical_op("logical_xor", "xor") +logical_and = Op("logical_and") +logical_or = Op("logical_or") +logical_xor = Op("logical_xor") # ----------------------------------------------------------------------------- # Composed binary operations # ----------------------------------------------------------------------------- -# logaddexp: log(exp(x) + exp(y)), numerically stable via max trick logaddexp = Op("logaddexp") -@logaddexp.impl("hlo") -def _logaddexp_hlo(x, y): +@logaddexp.composed_impl +def _logaddexp(x, y): from nkipy.core.ops.unary import exp, log m = maximum(x, y) return add(m, log(add(exp(subtract(x, m)), exp(subtract(y, m))))) -# remainder: a - b * floor(a / b) remainder = Op("remainder") -@remainder.impl("hlo") -def _remainder_hlo(x, y): +@remainder.composed_impl +def _remainder(x, y): from nkipy.core.ops.unary import floor return subtract(x, multiply(y, floor(divide(x, y)))) -# floor_divide: floor(a / b) floor_divide = Op("floor_divide") -@floor_divide.impl("hlo") -def _floor_divide_hlo(x, y): +@floor_divide.composed_impl +def _floor_divide(x, y): from nkipy.core.ops.unary import floor return floor(divide(x, y)) diff --git a/nkipy/src/nkipy/core/ops/collectives.py b/nkipy/src/nkipy/core/ops/collectives.py index ef04d2d..f181e30 100644 --- a/nkipy/src/nkipy/core/ops/collectives.py +++ b/nkipy/src/nkipy/core/ops/collectives.py @@ -4,9 +4,7 @@ import numpy as np -from nkipy.core.backend.hlo import get_hlo_context from nkipy.core.ops._registry import Op -from nkipy.core.tensor import NKIPyTensorRef # ----------------------------------------------------------------------------- # all_gather @@ -16,14 +14,7 @@ @all_gather.impl("cpu") def _all_gather_cpu(data, all_gather_dim, replica_groups, **kwargs): - """CPU implementation of all_gather with duplicated data assumption. - - In CPU execution, we assume all ranks have identical data. - all_gather collects data from all ranks and concatenates along the gather dimension. - - Shape: dim[all_gather_dim] *= world_size - Data: Replicate along gather dimension (same data from all ranks) - """ + """CPU implementation of all_gather with duplicated data assumption.""" world_size = len(replica_groups[0]) tile_reps = tuple( world_size if i == all_gather_dim else 1 for i in range(data.ndim) @@ -31,28 +22,6 @@ def _all_gather_cpu(data, all_gather_dim, replica_groups, **kwargs): return np.tile(data, tile_reps) -@all_gather.impl("hlo") -def _all_gather_hlo(data, all_gather_dim, replica_groups, **kwargs): - ctx = get_hlo_context() - - rank = len(replica_groups[0]) - out_shape = list(data.shape) - if out_shape: - out_shape[all_gather_dim] *= rank - - result_tensor = ctx.build_op( - "all-gather", - [data.backend_tensor], - tuple(out_shape), - data.dtype, - { - "all_gather_dim": all_gather_dim, - "replica_groups": replica_groups, - }, - ) - return NKIPyTensorRef(result_tensor) - - # ----------------------------------------------------------------------------- # all_reduce # ----------------------------------------------------------------------------- @@ -61,16 +30,7 @@ def _all_gather_hlo(data, all_gather_dim, replica_groups, **kwargs): @all_reduce.impl("cpu") def _all_reduce_cpu(data, replica_groups, reduce_op=np.add, **kwargs): - """CPU implementation of all_reduce with duplicated data assumption. - - In CPU execution, we assume all ranks have identical data. - all_reduce applies the reduction operation across all ranks. - - Shape: unchanged - Data: For add, result = data * world_size (sum of identical values) - For max/min, result = data (max/min of identical values) - For multiply, result = data ** world_size - """ + """CPU implementation of all_reduce with duplicated data assumption.""" world_size = len(replica_groups[0]) if reduce_op == np.add: @@ -80,36 +40,9 @@ def _all_reduce_cpu(data, replica_groups, reduce_op=np.add, **kwargs): elif reduce_op == np.multiply: return data**world_size else: - # Default: return copy return data.copy() -@all_reduce.impl("hlo") -def _all_reduce_hlo(data, replica_groups, reduce_op=np.add, **kwargs): - ctx = get_hlo_context() - - # Map reduce_op to string for HLO - reduce_op_map = { - np.add: "add", - np.multiply: "multiply", - np.maximum: "maximum", - np.minimum: "minimum", - } - reduce_op_str = reduce_op_map.get(reduce_op, "add") - - result_tensor = ctx.build_op( - "all-reduce", - [data.backend_tensor], - data.shape, - data.dtype, - { - "replica_groups": replica_groups, - "reduce_op": reduce_op_str, - }, - ) - return NKIPyTensorRef(result_tensor) - - # ----------------------------------------------------------------------------- # reduce_scatter # ----------------------------------------------------------------------------- @@ -120,17 +53,9 @@ def _all_reduce_hlo(data, replica_groups, reduce_op=np.add, **kwargs): def _reduce_scatter_cpu( data, reduce_scatter_dim: int, replica_groups, reduce_op=np.add, **kwargs ): - """CPU implementation of reduce_scatter with duplicated data assumption. - - In CPU execution, we assume all ranks have identical data. - reduce_scatter first reduces across ranks, then scatters the result. - - Shape: dim[reduce_scatter_dim] //= world_size - Data: First reduce (as in all_reduce), then take 1/world_size slice (as rank 0) - """ + """CPU implementation of reduce_scatter with duplicated data assumption.""" world_size = len(replica_groups[0]) - # First apply reduction (as in all_reduce) if reduce_op == np.add: reduced = data * world_size elif reduce_op == np.maximum or reduce_op == np.minimum: @@ -140,44 +65,10 @@ def _reduce_scatter_cpu( else: reduced = data.copy() - # Then scatter: take the first chunk (simulating rank 0) chunk_size = data.shape[reduce_scatter_dim] // world_size return np.take(reduced, range(chunk_size), axis=reduce_scatter_dim) -@reduce_scatter.impl("hlo") -def _reduce_scatter_hlo( - data, reduce_scatter_dim: int, replica_groups, reduce_op=np.add, **kwargs -): - ctx = get_hlo_context() - rank = len(replica_groups[0]) - out_shape = list(data.shape) - if out_shape: - out_shape[reduce_scatter_dim] //= rank - - # Map reduce_op to string for HLO - reduce_op_map = { - np.add: "add", - np.multiply: "multiply", - np.maximum: "maximum", - np.minimum: "minimum", - } - reduce_op_str = reduce_op_map.get(reduce_op, "add") - - result_tensor = ctx.build_op( - "reduce-scatter", - [data.backend_tensor], - tuple(out_shape), - data.dtype, - { - "reduce_scatter_dim": reduce_scatter_dim, - "replica_groups": replica_groups, - "reduce_op": reduce_op_str, - }, - ) - return NKIPyTensorRef(result_tensor) - - # ----------------------------------------------------------------------------- # all_to_all # ----------------------------------------------------------------------------- @@ -188,46 +79,10 @@ def _reduce_scatter_hlo( def _all_to_all_cpu( data, split_dimension: int, concat_dimension: int, replica_groups, **kwargs ): - """CPU implementation of all_to_all with duplicated data assumption. - - In CPU execution, we assume all ranks have identical data. - all_to_all splits data along split_dimension and redistributes along - concat_dimension. - - Shape: unchanged (splits and concats balance out) - Data: With duplicated data, effectively rearranges chunks between dimensions. - Since all ranks have the same data, the result is equivalent to: - - Split into world_size chunks along split_dimension - - Concatenate along concat_dimension - """ + """CPU implementation of all_to_all with duplicated data assumption.""" world_size = len(replica_groups[0]) - # Split along split_dimension chunks = np.split(data, world_size, axis=split_dimension) - - # Concatenate along concat_dimension - # With duplicated data, each rank would send its chunk to all others - # and receive chunks from all others. Since data is identical, - # we just rearrange the chunks. result = np.concatenate(chunks, axis=concat_dimension) return result - - -@all_to_all.impl("hlo") -def _all_to_all_hlo( - data, split_dimension: int, concat_dimension: int, replica_groups, **kwargs -): - ctx = get_hlo_context() - result_tensor = ctx.build_op( - "all-to-all", - [data.backend_tensor], - data.shape, - data.dtype, - { - "split_dimension": split_dimension, - "concat_dimension": concat_dimension, - "replica_groups": replica_groups, - }, - ) - return NKIPyTensorRef(result_tensor) diff --git a/nkipy/src/nkipy/core/ops/conv.py b/nkipy/src/nkipy/core/ops/conv.py index 5a4057c..91ac523 100644 --- a/nkipy/src/nkipy/core/ops/conv.py +++ b/nkipy/src/nkipy/core/ops/conv.py @@ -34,36 +34,23 @@ def _normalize_tuple_3d(value, name): def _im2col_2d(input_padded, kernel_h, kernel_w, stride_h, stride_w, out_h, out_w): - """Convert input to column matrix for efficient convolution via matmul. - - Args: - input_padded: Padded input tensor of shape (batch, in_channels, height, width) - kernel_h, kernel_w: Kernel dimensions - stride_h, stride_w: Stride values - out_h, out_w: Output spatial dimensions - - Returns: - Column matrix of shape (batch, in_channels * kernel_h * kernel_w, out_h * out_w) - """ + """Convert input to column matrix for efficient convolution via matmul.""" batch_size, in_channels, _, _ = input_padded.shape - # Use stride_tricks to create a view of sliding windows - # Shape: (batch, in_channels, out_h, out_w, kernel_h, kernel_w) shape = (batch_size, in_channels, out_h, out_w, kernel_h, kernel_w) strides = ( - input_padded.strides[0], # batch - input_padded.strides[1], # channel - input_padded.strides[2] * stride_h, # out_h (strided) - input_padded.strides[3] * stride_w, # out_w (strided) - input_padded.strides[2], # kernel_h - input_padded.strides[3], # kernel_w + input_padded.strides[0], + input_padded.strides[1], + input_padded.strides[2] * stride_h, + input_padded.strides[3] * stride_w, + input_padded.strides[2], + input_padded.strides[3], ) windows = np.lib.stride_tricks.as_strided( input_padded, shape=shape, strides=strides ) - # Reshape to (batch, in_channels * kernel_h * kernel_w, out_h * out_w) col = windows.transpose(0, 1, 4, 5, 2, 3).reshape( batch_size, in_channels * kernel_h * kernel_w, out_h * out_w ) @@ -83,40 +70,25 @@ def _im2col_3d( out_h, out_w, ): - """Convert input to column matrix for efficient 3D convolution via matmul. - - Args: - input_padded: Padded input tensor of shape - (batch, in_channels, depth, height, width) - kernel_d, kernel_h, kernel_w: Kernel dimensions - stride_d, stride_h, stride_w: Stride values - out_d, out_h, out_w: Output spatial dimensions - - Returns: - Column matrix of shape - (batch, in_channels * kernel_d * kernel_h * kernel_w, out_d * out_h * out_w) - """ + """Convert input to column matrix for efficient 3D convolution via matmul.""" batch_size, in_channels, _, _, _ = input_padded.shape - # Use stride_tricks to create a view of sliding windows - # Shape: (batch, in_channels, out_d, out_h, out_w, kernel_d, kernel_h, kernel_w) shape = (batch_size, in_channels, out_d, out_h, out_w, kernel_d, kernel_h, kernel_w) strides = ( - input_padded.strides[0], # batch - input_padded.strides[1], # channel - input_padded.strides[2] * stride_d, # out_d (strided) - input_padded.strides[3] * stride_h, # out_h (strided) - input_padded.strides[4] * stride_w, # out_w (strided) - input_padded.strides[2], # kernel_d - input_padded.strides[3], # kernel_h - input_padded.strides[4], # kernel_w + input_padded.strides[0], + input_padded.strides[1], + input_padded.strides[2] * stride_d, + input_padded.strides[3] * stride_h, + input_padded.strides[4] * stride_w, + input_padded.strides[2], + input_padded.strides[3], + input_padded.strides[4], ) windows = np.lib.stride_tricks.as_strided( input_padded, shape=shape, strides=strides ) - # (batch, in_channels * kernel_d * kernel_h * kernel_w, out_d * out_h * out_w) col = windows.transpose(0, 1, 5, 6, 7, 2, 3, 4).reshape( batch_size, in_channels * kernel_d * kernel_h * kernel_w, out_d * out_h * out_w ) @@ -142,28 +114,11 @@ def _conv2d_cpu( out=None, dtype=None, ): - """2D Convolution operation (CPU) using im2col approach. - - Args: - input: Input tensor of shape (batch, in_channels, height, width) - weight: Weight tensor of shape (out_channels, in_channels, kernel_h, kernel_w) - bias: Optional bias tensor of shape (out_channels,) - stride: Stride for convolution (int or tuple) - padding: Padding for input (int or tuple) - dilation: Dilation for kernel (int or tuple) - NOT IMPLEMENTED - groups: Number of groups for grouped convolution - NOT IMPLEMENTED - out: Output tensor (unused) - dtype: Output dtype (unused) - - Returns: - Output tensor of shape (batch, out_channels, out_height, out_width) - """ - # Normalize parameters + """2D Convolution operation (CPU) using im2col approach.""" stride = _normalize_tuple_2d(stride, "stride") padding = _normalize_tuple_2d(padding, "padding") dilation = _normalize_tuple_2d(dilation, "dilation") - # Check for unsupported features if dilation != (1, 1): raise NotImplementedError( f"conv2d CPU backend does not support dilation != 1, " @@ -177,15 +132,12 @@ def _conv2d_cpu( stride_h, stride_w = stride pad_h, pad_w = padding - # Get dimensions batch_size, in_channels, in_height, in_width = input.shape out_channels, _, kernel_h, kernel_w = weight.shape - # Calculate output dimensions out_height = (in_height + 2 * pad_h - kernel_h) // stride_h + 1 out_width = (in_width + 2 * pad_w - kernel_w) // stride_w + 1 - # Pad input if necessary if pad_h > 0 or pad_w > 0: input_padded = np.pad( input, @@ -196,120 +148,23 @@ def _conv2d_cpu( else: input_padded = np.ascontiguousarray(input) - # Ensure input is contiguous for stride_tricks if not input_padded.flags["C_CONTIGUOUS"]: input_padded = np.ascontiguousarray(input_padded) - # im2col: convert input to column matrix - # Shape: (batch, in_channels * kernel_h * kernel_w, out_height * out_width) col = _im2col_2d( input_padded, kernel_h, kernel_w, stride_h, stride_w, out_height, out_width ) - # Reshape weight for matrix multiplication - # Shape: (out_channels, in_channels * kernel_h * kernel_w) weight_reshaped = weight.reshape(out_channels, -1) - - # Perform convolution as matrix multiplication - # weight_reshaped: (out_channels, in_channels * kernel_h * kernel_w) - # col: (batch, in_channels * kernel_h * kernel_w, out_h * out_w) - # We want: (batch, out_channels, out_h * out_w) - # Use matmul: weight_reshaped @ col[b] for each batch output = np.matmul(weight_reshaped, col) - - # Reshape to output shape output = output.reshape(batch_size, out_channels, out_height, out_width) - # Add bias if provided if bias is not None: output = output + bias.reshape(1, -1, 1, 1) return output -@conv2d.impl("hlo") -def _conv2d_hlo( - input, - weight, - bias=None, - stride=1, - padding=0, - dilation=1, - groups=1, - out=None, - dtype=None, -): - """2D Convolution operation for HLO.""" - from nkipy.core.backend.hlo import broadcast_to_shape_hlo, get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(input, NKIPyTensorRef): - input = input.backend_tensor - if isinstance(weight, NKIPyTensorRef): - weight = weight.backend_tensor - - stride = _normalize_tuple_2d(stride, "stride") - dilation = _normalize_tuple_2d(dilation, "dilation") - padding_tuple = _normalize_tuple_2d(padding, "padding") - - batch_size, in_channels, in_height, in_width = input.shape - out_channels, _, kernel_height, kernel_width = weight.shape - - out_height = ( - in_height + 2 * padding_tuple[0] - dilation[0] * (kernel_height - 1) - 1 - ) // stride[0] + 1 - out_width = ( - in_width + 2 * padding_tuple[1] - dilation[1] * (kernel_width - 1) - 1 - ) // stride[1] + 1 - - output_shape = (batch_size, out_channels, out_height, out_width) - - padding_config = [ - (padding_tuple[0], padding_tuple[0]), - (padding_tuple[1], padding_tuple[1]), - ] - - result_tensor = ctx.build_op( - "convolution", - [input, weight], - output_shape, - input.dtype, - { - "window_strides": list(stride), - "padding": padding_config, - "lhs_dilation": [1, 1], - "rhs_dilation": list(dilation), - "feature_group_count": groups, - "batch_group_count": 1, - "input_batch_dimension": 0, - "input_feature_dimension": 1, - "input_spatial_dimensions": [2, 3], - "kernel_output_feature_dimension": 0, - "kernel_input_feature_dimension": 1, - "kernel_spatial_dimensions": [2, 3], - "output_batch_dimension": 0, - "output_feature_dimension": 1, - "output_spatial_dimensions": [2, 3], - }, - ) - - if bias is not None: - if isinstance(bias, NKIPyTensorRef): - bias = bias.backend_tensor - - bias_reshaped = ctx.build_op( - "reshape", [bias], (1, out_channels, 1, 1), bias.dtype - ) - bias_broadcast = broadcast_to_shape_hlo(ctx, bias_reshaped, output_shape) - result_tensor = ctx.build_op( - "add", [result_tensor, bias_broadcast], output_shape, input.dtype - ) - - return NKIPyTensorRef(result_tensor) - - # ----------------------------------------------------------------------------- # conv3d # ----------------------------------------------------------------------------- @@ -328,29 +183,11 @@ def _conv3d_cpu( out=None, dtype=None, ): - """3D Convolution operation (CPU) using im2col approach. - - Args: - input: Input tensor of shape (batch, in_channels, depth, height, width) - weight: Weight tensor of shape (out_channels, in_channels, kernel_d, - kernel_h, kernel_w) - bias: Optional bias tensor of shape (out_channels,) - stride: Stride for convolution (int or tuple) - padding: Padding for input (int or tuple) - dilation: Dilation for kernel (int or tuple) - NOT IMPLEMENTED - groups: Number of groups for grouped convolution - NOT IMPLEMENTED - out: Output tensor (unused) - dtype: Output dtype (unused) - - Returns: - Output tensor of shape (batch, out_channels, out_depth, out_height, out_width) - """ - # Normalize parameters + """3D Convolution operation (CPU) using im2col approach.""" stride = _normalize_tuple_3d(stride, "stride") padding = _normalize_tuple_3d(padding, "padding") dilation = _normalize_tuple_3d(dilation, "dilation") - # Check for unsupported features if dilation != (1, 1, 1): raise NotImplementedError( f"conv3d CPU backend does not support dilation != 1, " @@ -364,16 +201,13 @@ def _conv3d_cpu( stride_d, stride_h, stride_w = stride pad_d, pad_h, pad_w = padding - # Get dimensions batch_size, in_channels, in_depth, in_height, in_width = input.shape out_channels, _, kernel_d, kernel_h, kernel_w = weight.shape - # Calculate output dimensions out_depth = (in_depth + 2 * pad_d - kernel_d) // stride_d + 1 out_height = (in_height + 2 * pad_h - kernel_h) // stride_h + 1 out_width = (in_width + 2 * pad_w - kernel_w) // stride_w + 1 - # Pad input if necessary if pad_d > 0 or pad_h > 0 or pad_w > 0: input_padded = np.pad( input, @@ -384,13 +218,9 @@ def _conv3d_cpu( else: input_padded = np.ascontiguousarray(input) - # Ensure input is contiguous for stride_tricks if not input_padded.flags["C_CONTIGUOUS"]: input_padded = np.ascontiguousarray(input_padded) - # im2col: convert input to column matrix - # (batch, in_channels * kernel_d * kernel_h * kernel_w, - # out_depth * out_height * out_width) col = _im2col_3d( input_padded, kernel_d, @@ -404,108 +234,11 @@ def _conv3d_cpu( out_width, ) - # Reshape weight for matrix multiplication - # Shape: (out_channels, in_channels * kernel_d * kernel_h * kernel_w) weight_reshaped = weight.reshape(out_channels, -1) - - # Perform convolution as matrix multiplication - # weight_reshaped: (out_channels, in_channels * kernel_d * kernel_h * kernel_w) - # col: (batch, in_channels * kernel_d * kernel_h * kernel_w, out_d * out_h * out_w) - # We want: (batch, out_channels, out_d * out_h * out_w) output = np.matmul(weight_reshaped, col) - - # Reshape to output shape output = output.reshape(batch_size, out_channels, out_depth, out_height, out_width) - # Add bias if provided if bias is not None: output = output + bias.reshape(1, -1, 1, 1, 1) return output - - -@conv3d.impl("hlo") -def _conv3d_hlo( - input, - weight, - bias=None, - stride=1, - padding=0, - dilation=1, - groups=1, - out=None, - dtype=None, -): - """3D Convolution operation for HLO.""" - from nkipy.core.backend.hlo import broadcast_to_shape_hlo, get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(input, NKIPyTensorRef): - input = input.backend_tensor - if isinstance(weight, NKIPyTensorRef): - weight = weight.backend_tensor - - stride = _normalize_tuple_3d(stride, "stride") - dilation = _normalize_tuple_3d(dilation, "dilation") - padding_tuple = _normalize_tuple_3d(padding, "padding") - - batch_size, in_channels, in_depth, in_height, in_width = input.shape - out_channels, _, kernel_depth, kernel_height, kernel_width = weight.shape - - out_depth = ( - in_depth + 2 * padding_tuple[0] - dilation[0] * (kernel_depth - 1) - 1 - ) // stride[0] + 1 - out_height = ( - in_height + 2 * padding_tuple[1] - dilation[1] * (kernel_height - 1) - 1 - ) // stride[1] + 1 - out_width = ( - in_width + 2 * padding_tuple[2] - dilation[2] * (kernel_width - 1) - 1 - ) // stride[2] + 1 - - output_shape = (batch_size, out_channels, out_depth, out_height, out_width) - - padding_config = [ - (padding_tuple[0], padding_tuple[0]), - (padding_tuple[1], padding_tuple[1]), - (padding_tuple[2], padding_tuple[2]), - ] - - result_tensor = ctx.build_op( - "convolution", - [input, weight], - output_shape, - input.dtype, - { - "window_strides": list(stride), - "padding": padding_config, - "lhs_dilation": [1, 1, 1], - "rhs_dilation": list(dilation), - "feature_group_count": groups, - "batch_group_count": 1, - "input_batch_dimension": 0, - "input_feature_dimension": 1, - "input_spatial_dimensions": [2, 3, 4], - "kernel_output_feature_dimension": 0, - "kernel_input_feature_dimension": 1, - "kernel_spatial_dimensions": [2, 3, 4], - "output_batch_dimension": 0, - "output_feature_dimension": 1, - "output_spatial_dimensions": [2, 3, 4], - }, - ) - - if bias is not None: - if isinstance(bias, NKIPyTensorRef): - bias = bias.backend_tensor - - bias_reshaped = ctx.build_op( - "reshape", [bias], (1, out_channels, 1, 1, 1), bias.dtype - ) - bias_broadcast = broadcast_to_shape_hlo(ctx, bias_reshaped, output_shape) - result_tensor = ctx.build_op( - "add", [result_tensor, bias_broadcast], output_shape, input.dtype - ) - - return NKIPyTensorRef(result_tensor) diff --git a/nkipy/src/nkipy/core/ops/creation.py b/nkipy/src/nkipy/core/ops/creation.py index de50d52..5690150 100644 --- a/nkipy/src/nkipy/core/ops/creation.py +++ b/nkipy/src/nkipy/core/ops/creation.py @@ -2,607 +2,60 @@ # SPDX-License-Identifier: Apache-2.0 """Array creation operations: zeros, full, zeros_like, empty_like, full_like, ones_like""" -import builtins - import numpy as np from nkipy.core.ops._registry import Op -builtins_min = builtins.min - # ----------------------------------------------------------------------------- -# zeros +# Primitive creation ops # ----------------------------------------------------------------------------- zeros = Op("zeros") +full = Op("full") +constant = Op("constant") +zeros_like = Op("zeros_like") +ones_like = Op("ones_like") +empty_like = Op("empty_like") +full_like = Op("full_like") +tril = Op("tril") +triu = Op("triu") +diag = Op("diag") + + +# ----------------------------------------------------------------------------- +# CPU implementations (needed for non-tracing execution) +# ----------------------------------------------------------------------------- @zeros.impl("cpu") def _zeros_cpu(shape, dtype): - """Create a tensor filled with zeros (CPU).""" return np.zeros(shape, dtype=dtype) -@zeros.impl("hlo") -def _zeros_hlo(shape, dtype): - """Create a tensor filled with zeros (HLO).""" - from nkipy.core.backend.hlo import as_hlo_tensor, get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - # Normalize shape to tuple - if isinstance(shape, int): - shape = (shape,) - - # Create a scalar zero constant - zero_tensor = as_hlo_tensor(ctx, 0.0, dtype) - - # Broadcast to the target shape - if shape: - result_tensor = ctx.build_op( - "broadcast", [zero_tensor], shape, dtype, {"broadcast_dimensions": []} - ) - else: - result_tensor = zero_tensor - - return NKIPyTensorRef(result_tensor) - - -# ----------------------------------------------------------------------------- -# zeros_like -# ----------------------------------------------------------------------------- -zeros_like = Op("zeros_like") - - @zeros_like.impl("cpu") def _zeros_like_cpu(x, dtype=None): - """Create a tensor of zeros with the same shape as x (CPU).""" return np.zeros_like(x, dtype=dtype) -@zeros_like.impl("hlo") -def _zeros_like_hlo(x, dtype=None): - """Create a tensor of zeros with the same shape as x (HLO). - - Note: We need to reference the input tensor x in the computation graph - even though we don't use its values, to ensure the HLO module parameter - count matches the computation. - """ - from nkipy.core.backend.hlo import as_hlo_tensor, get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x_hlo = x.backend_tensor - else: - x_hlo = x - - result_dtype = dtype if dtype is not None else x_hlo.dtype - - # Create a scalar zero constant - zero_tensor = as_hlo_tensor(ctx, 0.0, result_dtype) - - # Broadcast to the target shape - if x_hlo.shape: - result_tensor = ctx.build_op( - "broadcast", - [zero_tensor], - x_hlo.shape, - result_dtype, - {"broadcast_dimensions": []}, - ) - else: - result_tensor = zero_tensor - - # FIXME: Workaround to ensure x is referenced in the computation graph. - # HLO requires all module parameters to be used in the computation - zero_multiplier = as_hlo_tensor(ctx, 0.0, x_hlo.dtype) - if x_hlo.shape: - zero_multiplier = ctx.build_op( - "broadcast", - [zero_multiplier], - x_hlo.shape, - x_hlo.dtype, - {"broadcast_dimensions": []}, - ) - - x_times_zero = ctx.build_op( - "multiply", [x_hlo, zero_multiplier], x_hlo.shape, x_hlo.dtype - ) - - if x_hlo.dtype != result_dtype: - x_times_zero = ctx.build_op( - "convert", [x_times_zero], x_hlo.shape, result_dtype - ) - - result_tensor = ctx.build_op( - "add", [result_tensor, x_times_zero], x_hlo.shape, result_dtype - ) - - return NKIPyTensorRef(result_tensor) - - -# ----------------------------------------------------------------------------- -# ones_like -# ----------------------------------------------------------------------------- -ones_like = Op("ones_like") - - @ones_like.impl("cpu") def _ones_like_cpu(x, dtype=None): - """Create a tensor of ones with the same shape as x (CPU).""" return np.ones_like(x, dtype=dtype) -@ones_like.impl("hlo") -def _ones_like_hlo(x, dtype=None): - """Create a tensor of ones with the same shape as x (HLO).""" - from nkipy.core.backend.hlo import as_hlo_tensor, get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x_hlo = x.backend_tensor - else: - x_hlo = x - - result_dtype = dtype if dtype is not None else x_hlo.dtype - - # Create a scalar one constant - one_tensor = as_hlo_tensor(ctx, 1.0, result_dtype) - - # Broadcast to the target shape - if x_hlo.shape: - result_tensor = ctx.build_op( - "broadcast", - [one_tensor], - x_hlo.shape, - result_dtype, - {"broadcast_dimensions": []}, - ) - else: - result_tensor = one_tensor - - # FIXME: Workaround to ensure x is referenced in the computation graph - zero_multiplier = as_hlo_tensor(ctx, 0.0, x_hlo.dtype) - if x_hlo.shape: - zero_multiplier = ctx.build_op( - "broadcast", - [zero_multiplier], - x_hlo.shape, - x_hlo.dtype, - {"broadcast_dimensions": []}, - ) - - x_times_zero = ctx.build_op( - "multiply", [x_hlo, zero_multiplier], x_hlo.shape, x_hlo.dtype - ) - - if x_hlo.dtype != result_dtype: - x_times_zero = ctx.build_op( - "convert", [x_times_zero], x_hlo.shape, result_dtype - ) - - result_tensor = ctx.build_op( - "add", [result_tensor, x_times_zero], x_hlo.shape, result_dtype - ) - - return NKIPyTensorRef(result_tensor) - - -# ----------------------------------------------------------------------------- -# empty_like -# ----------------------------------------------------------------------------- -empty_like = Op("empty_like") - - @empty_like.impl("cpu") def _empty_like_cpu(x, dtype=None): - """Create an uninitialized tensor with the same shape as x (CPU).""" return np.empty_like(x, dtype=dtype) -@empty_like.impl("hlo") -def _empty_like_hlo(x, dtype=None): - """Create an uninitialized tensor with the same shape as x (HLO). - - For HLO, same as zeros_like since we can't have uninitialized memory. - """ - return _zeros_like_hlo(x, dtype=dtype) - - -# ----------------------------------------------------------------------------- -# full_like -# ----------------------------------------------------------------------------- -full_like = Op("full_like") - - @full_like.impl("cpu") def _full_like_cpu(x, fill_value, dtype=None): - """Create a tensor filled with fill_value with the same shape as x (CPU).""" return np.full_like(x, fill_value, dtype=dtype) -@full_like.impl("hlo") -def _full_like_hlo(x, fill_value, dtype=None): - """Create a tensor filled with fill_value with the same shape as x (HLO).""" - from nkipy.core.backend.hlo import as_hlo_tensor, get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x_hlo = x.backend_tensor - else: - x_hlo = x - - result_dtype = dtype if dtype is not None else x_hlo.dtype - - # Create a scalar constant with the fill value - fill_tensor = as_hlo_tensor(ctx, fill_value, result_dtype) - - # Broadcast to the target shape - if x_hlo.shape: - result_tensor = ctx.build_op( - "broadcast", - [fill_tensor], - x_hlo.shape, - result_dtype, - {"broadcast_dimensions": []}, - ) - else: - result_tensor = fill_tensor - - # FIXME: Workaround to ensure x is referenced in the computation graph - zero_multiplier = as_hlo_tensor(ctx, 0.0, x_hlo.dtype) - if x_hlo.shape: - zero_multiplier = ctx.build_op( - "broadcast", - [zero_multiplier], - x_hlo.shape, - x_hlo.dtype, - {"broadcast_dimensions": []}, - ) - - x_times_zero = ctx.build_op( - "multiply", [x_hlo, zero_multiplier], x_hlo.shape, x_hlo.dtype - ) - - if x_hlo.dtype != result_dtype: - x_times_zero = ctx.build_op( - "convert", [x_times_zero], x_hlo.shape, result_dtype - ) - - result_tensor = ctx.build_op( - "add", [result_tensor, x_times_zero], x_hlo.shape, result_dtype - ) - - return NKIPyTensorRef(result_tensor) - - -# ----------------------------------------------------------------------------- -# full -# ----------------------------------------------------------------------------- -full = Op("full") - - @full.impl("cpu") def _full_cpu(shape, fill_value, dtype): - """Create a tensor filled with a constant value (CPU).""" return np.full(shape, fill_value, dtype=dtype) -@full.impl("hlo") -def _full_hlo(shape, fill_value, dtype): - """Create a tensor filled with a constant value (HLO).""" - from nkipy.core.backend.hlo import as_hlo_tensor, get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - # Normalize shape to tuple - if isinstance(shape, int): - shape = (shape,) - - # Create a scalar constant with the fill value - fill_tensor = as_hlo_tensor(ctx, fill_value, dtype) - - # Broadcast to the target shape - if shape: - result_tensor = ctx.build_op( - "broadcast", [fill_tensor], shape, dtype, {"broadcast_dimensions": []} - ) - else: - result_tensor = fill_tensor - - return NKIPyTensorRef(result_tensor) - - -# ----------------------------------------------------------------------------- -# constant - promote compile-time numpy arrays to runtime tensors -# ----------------------------------------------------------------------------- -constant = Op("constant") - - @constant.impl("cpu") def _constant_cpu(value, dtype=None): - """Convert value to numpy array (CPU). - - When not tracing, this simply ensures the value is an ndarray. - """ return np.asarray(value, dtype=dtype) - - -@constant.impl("hlo") -def _constant_hlo(value, dtype=None): - """Promote a numpy array or scalar to an HLO constant tensor. - - This is the primary API for promoting compile-time numpy arrays - to runtime tensors during HLO tracing. It wraps the value as a - single HLO constant op. - - Behavior: - - NKIPyTensorRef: pass-through (idempotent), with optional dtype cast - - np.ndarray: create HLO constant - - scalar (int, float, bool): create scalar HLO constant - - list/tuple: convert to np.ndarray first, then create HLO constant - """ - from nkipy.core.backend.hlo import as_hlo_tensor, get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - # Idempotent: if already a traced tensor, just handle dtype - if isinstance(value, NKIPyTensorRef): - if dtype is not None and value.dtype != np.dtype(dtype): - from nkipy.core.ops.transform import astype - - return astype(value, dtype) - return value - - ctx = get_hlo_context() - - # Determine target dtype - if dtype is not None: - target_dtype = np.dtype(dtype) - elif hasattr(value, "dtype"): - target_dtype = np.dtype(value.dtype) - elif isinstance(value, float): - target_dtype = np.dtype(np.float32) - elif isinstance(value, int): - target_dtype = np.dtype(np.int32) - elif isinstance(value, bool): - target_dtype = np.dtype(np.bool_) - else: - target_dtype = np.dtype(np.asarray(value).dtype) - - # Convert lists/tuples to ndarray - if isinstance(value, (list, tuple)): - value = np.asarray(value, dtype=target_dtype) - - hlo_tensor = as_hlo_tensor(ctx, value, target_dtype) - return NKIPyTensorRef(hlo_tensor) - - -# ----------------------------------------------------------------------------- -# tril - lower triangle of an array -# ----------------------------------------------------------------------------- -tril = Op("tril") - - -@tril.impl("hlo") -def _tril_hlo(x, k=0): - """Lower triangle: where(row >= col - k, x, 0).""" - from nkipy.core.backend.hlo import get_hlo_context - from nkipy.core.ops.binary import greater_equal, subtract - from nkipy.core.ops.indexing import where - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x_bt = x.backend_tensor - else: - x_bt = x - - shape = x_bt.shape - ndim = len(shape) - - # Create iota for row indices (second-to-last dim) - row_iota = ctx.build_op( - "iota", [], shape, np.dtype(np.int32), {"iota_dimension": ndim - 2} - ) - row_ref = NKIPyTensorRef(row_iota) - - # Create iota for col indices (last dim) - col_iota = ctx.build_op( - "iota", [], shape, np.dtype(np.int32), {"iota_dimension": ndim - 1} - ) - col_ref = NKIPyTensorRef(col_iota) - - # Mask: row >= col - k → row + k >= col → row - (col - k) >= 0 - if k != 0: - col_ref = subtract(col_ref, k) - mask = greater_equal(row_ref, col_ref) - - return where(mask, x, 0.0) - - -# ----------------------------------------------------------------------------- -# triu - upper triangle of an array -# ----------------------------------------------------------------------------- -triu = Op("triu") - - -@triu.impl("hlo") -def _triu_hlo(x, k=0): - """Upper triangle: where(row <= col - k, x, 0).""" - from nkipy.core.backend.hlo import get_hlo_context - from nkipy.core.ops.binary import less_equal, subtract - from nkipy.core.ops.indexing import where - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x_bt = x.backend_tensor - else: - x_bt = x - - shape = x_bt.shape - ndim = len(shape) - - # Create iota for row indices (second-to-last dim) - row_iota = ctx.build_op( - "iota", [], shape, np.dtype(np.int32), {"iota_dimension": ndim - 2} - ) - row_ref = NKIPyTensorRef(row_iota) - - # Create iota for col indices (last dim) - col_iota = ctx.build_op( - "iota", [], shape, np.dtype(np.int32), {"iota_dimension": ndim - 1} - ) - col_ref = NKIPyTensorRef(col_iota) - - # Mask: row <= col - k - if k != 0: - col_ref = subtract(col_ref, k) - mask = less_equal(row_ref, col_ref) - - return where(mask, x, 0.0) - - -# ----------------------------------------------------------------------------- -# diag - extract diagonal or construct diagonal matrix -# ----------------------------------------------------------------------------- -diag = Op("diag") - - -@diag.impl("hlo") -def _diag_hlo(v, k=0): - """Extract diagonal from 2D → 1D, or create diagonal matrix from 1D → 2D.""" - from nkipy.core.backend.hlo import get_hlo_context - from nkipy.core.ops.binary import equal, subtract - from nkipy.core.ops.indexing import where - from nkipy.core.ops.reduce import sum - from nkipy.core.ops.transform import broadcast_to, reshape - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(v, NKIPyTensorRef): - v_bt = v.backend_tensor - else: - v_bt = v - - ndim = len(v_bt.shape) - - if ndim == 1: - # 1D → 2D: construct diagonal matrix - n = v_bt.shape[0] + abs(k) - shape_2d = (n, n) - - # Create row and col iota - row_iota = ctx.build_op( - "iota", [], shape_2d, np.dtype(np.int32), {"iota_dimension": 0} - ) - row_ref = NKIPyTensorRef(row_iota) - col_iota = ctx.build_op( - "iota", [], shape_2d, np.dtype(np.int32), {"iota_dimension": 1} - ) - col_ref = NKIPyTensorRef(col_iota) - - if k != 0: - col_ref = subtract(col_ref, k) - - diag_mask = equal(row_ref, col_ref) - - # Broadcast v to (n, n): put v along the diagonal - # v has length v_bt.shape[0], we need to gather it at the right positions - # Use iota as gather index: for each (i, j) on the diagonal, take v[i] (or v[j-k]) - if k >= 0: - # Diagonal starts at column k: row index gives position in v - idx_ref = row_ref - else: - # Diagonal starts at row -k: raw col index gives position in v - idx_ref = NKIPyTensorRef(col_iota) - - # Clamp indices to valid range — out-of-bounds positions are zeroed by the mask - from nkipy.core.ops.unary import clip - - idx_ref = clip(idx_ref, 0, v_bt.shape[0] - 1) - - # Use take to gather v values at idx positions, then mask - from nkipy.core.ops.indexing import take - - v_gathered = take(v, idx_ref, axis=0) - return where(diag_mask, v_gathered, 0.0) - - elif ndim == 2: - # 2D → 1D: extract diagonal - rows, cols = v_bt.shape - if k >= 0: - diag_len = builtins_min(rows, cols - k) - else: - diag_len = builtins_min(rows + k, cols) - - if diag_len <= 0: - # Return empty-ish result — use shape (0,) but HLO needs >= 1 - # Just return a size-1 zero as a fallback - return zeros((0,), v_bt.dtype) - - # Create indices for the diagonal - diag_indices = np.arange(diag_len, dtype=np.int32) - if k >= 0: - row_indices = diag_indices - col_indices = diag_indices + k - else: - row_indices = diag_indices - k - col_indices = diag_indices - - # Create the 2D mask approach: iota == expected diagonal position - shape_2d = v_bt.shape - row_iota = ctx.build_op( - "iota", [], shape_2d, np.dtype(np.int32), {"iota_dimension": 0} - ) - row_ref = NKIPyTensorRef(row_iota) - col_iota = ctx.build_op( - "iota", [], shape_2d, np.dtype(np.int32), {"iota_dimension": 1} - ) - col_ref = NKIPyTensorRef(col_iota) - - if k != 0: - col_ref = subtract(col_ref, k) - - diag_mask = equal(row_ref, col_ref) - - # Mask and sum along cols to extract diagonal - masked = where(diag_mask, v, 0.0) - - # Sum along the appropriate axis to collapse to 1D - if k >= 0: - # Sum along columns, then slice to diag_len - result = sum(masked, axis=1) - else: - # Sum along rows, then slice to diag_len - result = sum(masked, axis=0) - - # Slice to the correct diagonal length - from nkipy.core.ops.transform import astype as astype_op - - result_shape = result.shape - if result_shape[0] != diag_len: - from nkipy.core.ops.indexing import static_slice - - result = static_slice( - result, - [0], - [diag_len], - [1], - [], - ) - - return result - - else: - raise ValueError(f"Input must be 1-D or 2-D, got {ndim}-D") diff --git a/nkipy/src/nkipy/core/ops/indexing.py b/nkipy/src/nkipy/core/ops/indexing.py index 4f7c5ff..9a5246b 100644 --- a/nkipy/src/nkipy/core/ops/indexing.py +++ b/nkipy/src/nkipy/core/ops/indexing.py @@ -4,859 +4,16 @@ static_slice, dynamic_update_slice, scatter_strided """ -import itertools -from typing import List, Tuple - -import numpy as np - from nkipy.core.ops._registry import Op # ----------------------------------------------------------------------------- -# where +# Primitive indexing ops # ----------------------------------------------------------------------------- where = Op("where") - - -@where.impl("hlo") -def _where_hlo(condition, x, y): - """Select elements from x or y based on condition (HLO).""" - from nkipy.core.backend.hlo import ( - HLOOp, - as_hlo_tensor, - broadcast_to_shape_hlo, - find_common_type_hlo, - get_hlo_context, - ) - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - output_dtype = find_common_type_hlo(x, y) - - # Convert inputs to HLO tensors - if isinstance(condition, NKIPyTensorRef): - condition = condition.backend_tensor - elif np.isscalar(condition): - condition = as_hlo_tensor(ctx, bool(condition), np.bool_) - elif isinstance(condition, np.ndarray): - condition = as_hlo_tensor(ctx, condition.astype(bool), np.bool_) - - # If other integer type, convert to bool first - if hasattr(condition, "dtype") and condition.dtype != np.bool_: - zero = as_hlo_tensor(ctx, 0, condition.dtype) - if condition.shape: - zero = ctx.build_op( - "broadcast", - [zero], - condition.shape, - condition.dtype, - {"broadcast_dimensions": []}, - ) - condition = ctx.build_op( - "compare", - [condition, zero], - condition.shape, - np.bool_, - {"comparison_direction": "NE"}, - ) - - if isinstance(x, NKIPyTensorRef): - x = x.backend_tensor - elif np.isscalar(x): - # Use output_dtype for scalar conversion to avoid losing precision - x = as_hlo_tensor(ctx, x, output_dtype) - elif isinstance(x, np.ndarray): - const_op = HLOOp( - "constant", - [], - result_shape=x.shape, - result_dtype=x.dtype, - attributes={"value": x}, - ) - x = ctx.module.add_operation(const_op) - - if isinstance(y, NKIPyTensorRef): - y = y.backend_tensor - elif np.isscalar(y): - # Use output_dtype for scalar conversion to avoid losing precision - y = as_hlo_tensor(ctx, y, output_dtype) - elif isinstance(y, np.ndarray): - const_op = HLOOp( - "constant", - [], - result_shape=y.shape, - result_dtype=y.dtype, - attributes={"value": y}, - ) - y = ctx.module.add_operation(const_op) - - # Broadcast all three inputs to a common shape - broadcast_shape = tuple(np.broadcast_shapes(condition.shape, x.shape, y.shape)) - - if condition.shape != broadcast_shape: - condition = broadcast_to_shape_hlo(ctx, condition, broadcast_shape) - - if x.shape != broadcast_shape: - x = broadcast_to_shape_hlo(ctx, x, broadcast_shape) - - if y.shape != broadcast_shape: - y = broadcast_to_shape_hlo(ctx, y, broadcast_shape) - - result_tensor = ctx.build_op( - "select", [condition, x, y], broadcast_shape, output_dtype - ) - - return NKIPyTensorRef(result_tensor) - - -# ----------------------------------------------------------------------------- -# take -# ----------------------------------------------------------------------------- take = Op("take") - - -@take.impl("hlo") -def _take_hlo(x, indices, axis=None): - """Take elements from an array along an axis (HLO).""" - from nkipy.core.backend.hlo import HLOOp, as_hlo_tensor, get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x = x.backend_tensor - - # Handle axis=None case - flatten the array first - if axis is None: - flattened_shape = (int(np.prod(x.shape)),) - x = ctx.build_op("reshape", [x], flattened_shape, x.dtype) - axis = 0 - - # Normalize negative axis - if axis < 0: - axis = len(x.shape) + axis - - dtype = x.dtype - - # Convert indices to HLO tensor - if isinstance(indices, NKIPyTensorRef): - indices_tensor = indices.backend_tensor - elif np.isscalar(indices): - if indices < 0: - indices = x.shape[axis] + indices - indices_tensor = as_hlo_tensor(ctx, int(indices), np.dtype(np.int32)) - elif isinstance(indices, (np.ndarray, list)): - if isinstance(indices, list): - indices_np = np.array(indices, dtype=np.int32) - else: - indices_np = indices.astype(np.int32) - const_op = HLOOp( - "constant", - [], - result_shape=indices_np.shape, - result_dtype=np.dtype(np.int32), - attributes={"value": indices_np}, - ) - indices_tensor = ctx.module.add_operation(const_op) - else: - raise ValueError( - "np.take only supports TensorRef, scalar, np.ndarray, or list as indices!" - ) - - indices_shape = indices_tensor.shape if hasattr(indices_tensor, "shape") else () - - # Build output shape - output_shape = [] - for i in range(len(x.shape)): - if i == axis: - output_shape.extend(indices_shape) - else: - output_shape.append(x.shape[i]) - output_shape = tuple(output_shape) - - # Configure gather dimension numbers - offset_dims = [] - for i in range(len(output_shape)): - if i < axis or i >= axis + len(indices_shape): - offset_dims.append(i) - - collapsed_slice_dims = [axis] - start_index_map = [axis] - index_vector_dim = len(indices_shape) - - slice_sizes = list(x.shape) - slice_sizes[axis] = 1 - - result_tensor = ctx.build_op( - "gather", - [x, indices_tensor], - output_shape, - dtype, - { - "offset_dims": offset_dims, - "collapsed_slice_dims": collapsed_slice_dims, - "start_index_map": start_index_map, - "index_vector_dim": index_vector_dim, - "slice_sizes": slice_sizes, - "indices_are_sorted": False, - }, - ) - - return NKIPyTensorRef(result_tensor) - - -# ----------------------------------------------------------------------------- -# take_along_axis -# ----------------------------------------------------------------------------- take_along_axis = Op("take_along_axis") - - -@take_along_axis.impl("hlo") -def _take_along_axis_hlo(x, indices, axis): - """Take values from the input array by matching 1d index and data slices (HLO).""" - from nkipy.core.backend.hlo import ( - HLOOp, - broadcast_to_shape_hlo, - get_hlo_context, - ) - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x = x.backend_tensor - - # Handle axis=None case - if axis is None: - flattened_shape = (int(np.prod(x.shape)),) - x = ctx.build_op("reshape", [x], flattened_shape, x.dtype) - axis = 0 - - # Normalize negative axis - if axis < 0: - axis = len(x.shape) + axis - - # Convert indices to HLO tensor - if isinstance(indices, NKIPyTensorRef): - indices_tensor = indices.backend_tensor - elif isinstance(indices, np.ndarray): - indices_np = indices.astype(np.int32) - const_op = HLOOp( - "constant", - [], - result_shape=indices_np.shape, - result_dtype=np.dtype(np.int32), - attributes={"value": indices_np}, - ) - indices_tensor = ctx.module.add_operation(const_op) - else: - raise ValueError( - "take_along_axis only supports TensorRef or np.ndarray as indices!" - ) - - data_rank = len(x.shape) - - # Convert indices to int32 if needed - if indices_tensor.dtype != np.dtype(np.int32): - indices_tensor = ctx.build_op( - "convert", [indices_tensor], indices_tensor.shape, np.dtype(np.int32) - ) - - # Broadcast indices to match operand shape - target_indices_shape = list(x.shape) - target_indices_shape[axis] = ( - indices_tensor.shape[axis] if axis < len(indices_tensor.shape) else 1 - ) - target_indices_shape = tuple(target_indices_shape) - - if indices_tensor.shape != target_indices_shape: - indices_tensor = broadcast_to_shape_hlo( - ctx, indices_tensor, target_indices_shape - ) - - # Create index arrays for all dimensions - index_arrays = [] - for i in range(data_rank): - if i == axis: - index_arrays.append(indices_tensor) - else: - arange_shape = [1] * data_rank - arange_shape[i] = x.shape[i] - arange_shape = tuple(arange_shape) - - arange_vals = np.arange(x.shape[i], dtype=np.int32) - const_op = HLOOp( - "constant", - [], - result_shape=(x.shape[i],), - result_dtype=np.dtype(np.int32), - attributes={"value": arange_vals}, - ) - arange_tensor = ctx.module.add_operation(const_op) - - arange_tensor = ctx.build_op( - "reshape", [arange_tensor], arange_shape, np.dtype(np.int32) - ) - index_arrays.append(arange_tensor) - - # Broadcast all index arrays to the same shape - broadcast_shape = target_indices_shape - broadcasted_indices = [] - for idx_array in index_arrays: - if idx_array.shape != broadcast_shape: - broadcasted = broadcast_to_shape_hlo(ctx, idx_array, broadcast_shape) - broadcasted_indices.append(broadcasted) - else: - broadcasted_indices.append(idx_array) - - # Stack indices along last dimension - reshaped_indices = [] - for idx in broadcasted_indices: - new_shape = idx.shape + (1,) - reshaped = ctx.build_op("reshape", [idx], new_shape, np.dtype(np.int32)) - reshaped_indices.append(reshaped) - - stacked_shape = broadcast_shape + (data_rank,) - gather_indices = ctx.build_op( - "concatenate", - reshaped_indices, - stacked_shape, - np.dtype(np.int32), - {"dimension": data_rank}, - ) - - # Configure gather dimension numbers - dnums = { - "offset_dims": [], - "collapsed_slice_dims": list(range(data_rank)), - "start_index_map": list(range(data_rank)), - "index_vector_dim": data_rank, - "slice_sizes": [1] * data_rank, - "indices_are_sorted": False, - } - - result_tensor = ctx.build_op( - "gather", - [x, gather_indices], - broadcast_shape, - x.dtype, - dnums, - ) - - return NKIPyTensorRef(result_tensor) - - -# ----------------------------------------------------------------------------- -# scatter_along_axis (internal) - Window-level scatter with 1D indices -# ----------------------------------------------------------------------------- -# Used by _do_scatter_indexing (__setitem__ with tensor indices). -# NOT the same as np.put_along_axis: this scatters entire rows/columns -# (1D indices, update_window_dims covers non-axis dims), while numpy's -# put_along_axis does element-level scatter (indices same ndim as array). scatter_along_axis = Op("scatter_along_axis") - - -@scatter_along_axis.impl("hlo") -def _scatter_along_axis_hlo(x, indices, values, axis): - """Scatter whole slices along an axis using 1D indices. - - For a 2D array with axis=1 and indices [2, 0]: - a[:, 2] = values[:, 0] - a[:, 0] = values[:, 1] - """ - from nkipy.core.backend.hlo import HLOOp, as_hlo_tensor, get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x = x.backend_tensor - - x_copy = ctx.build_op("copy", [x], x.shape, x.dtype) - - if axis < 0: - axis = len(x.shape) + axis - - # Convert indices to HLO tensor - if isinstance(indices, NKIPyTensorRef): - indices_tensor = indices.backend_tensor - elif isinstance(indices, np.ndarray): - indices_np = indices.astype(np.int32) - const_op = HLOOp( - "constant", - [], - result_shape=indices_np.shape, - result_dtype=np.dtype(np.int32), - attributes={"value": indices_np}, - ) - indices_tensor = ctx.module.add_operation(const_op) - else: - raise ValueError("scatter_along_axis requires TensorRef or np.ndarray indices") - - # Convert values to HLO tensor - if isinstance(values, NKIPyTensorRef): - values_tensor = values.backend_tensor - elif isinstance(values, np.ndarray): - values_np = values.astype(x.dtype) - const_op = HLOOp( - "constant", - [], - result_shape=values_np.shape, - result_dtype=x.dtype, - attributes={"value": values_np}, - ) - values_tensor = ctx.module.add_operation(const_op) - else: - values_tensor = as_hlo_tensor(ctx, values, x.dtype) - - # Window-level scatter: update_window_dims covers all non-axis dims - update_window_dims = [i for i in range(len(x_copy.shape)) if i != axis] - scattered_tensor = ctx.build_op( - "scatter", - [x_copy, indices_tensor, values_tensor], - x_copy.shape, - x.dtype, - { - "update_window_dims": update_window_dims, - "inserted_window_dims": [axis], - "scatter_dims_to_operand_dims": [axis], - "index_vector_dim": len(indices_tensor.shape), - "update_computation": "assign", - "indices_are_sorted": False, - "unique_indices": False, - }, - ) - - return NKIPyTensorRef(scattered_tensor) - - -# ----------------------------------------------------------------------------- -# put_along_axis - numpy-compatible element-level scatter -# ----------------------------------------------------------------------------- put_along_axis = Op("put_along_axis") - - -@put_along_axis.impl("hlo") -def _put_along_axis_hlo(x, indices, values, axis): - """Element-level scatter matching np.put_along_axis semantics. - - For each position (i, j, ...) in indices: - arr[..., indices[i,j,...], ...] = values[i,j,...] - where the index value replaces the coordinate along ``axis``. - - Lowered to a flat 1D scatter for Neuron compiler compatibility. - """ - from nkipy.core.backend.hlo import HLOOp, as_hlo_tensor, get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x = x.backend_tensor - - x_shape = x.shape - x_dtype = x.dtype - x_copy = ctx.build_op("copy", [x], x_shape, x_dtype) - - # Normalize axis - if axis is None: - axis = 0 - effective_shape = (int(np.prod(x_shape)),) - else: - if axis < 0: - axis = len(x_shape) + axis - effective_shape = x_shape - - # --- Convert indices to HLO tensor --- - if isinstance(indices, NKIPyTensorRef): - indices_tensor = indices.backend_tensor - elif isinstance(indices, np.ndarray): - indices_np = indices.astype(np.int32) - const_op = HLOOp( - "constant", - [], - result_shape=indices_np.shape, - result_dtype=np.dtype(np.int32), - attributes={"value": indices_np}, - ) - indices_tensor = ctx.module.add_operation(const_op) - else: - raise ValueError( - "put_along_axis only supports TensorRef or np.ndarray as indices!" - ) - - # Ensure int32 for arithmetic (indices may be uint32 from user code) - if indices_tensor.dtype != np.dtype(np.int32): - indices_tensor = ctx.build_op( - "convert", [indices_tensor], indices_tensor.shape, np.dtype(np.int32) - ) - - idx_shape = indices_tensor.shape - - # --- Convert values to HLO tensor --- - if np.isscalar(values): - scalar_tensor = as_hlo_tensor(ctx, values, x_dtype) - if idx_shape: - values_tensor = ctx.build_op( - "broadcast", - [scalar_tensor], - idx_shape, - x_dtype, - {"broadcast_dimensions": []}, - ) - else: - values_tensor = scalar_tensor - elif isinstance(values, NKIPyTensorRef): - values_tensor = values.backend_tensor - elif isinstance(values, np.ndarray): - values_np = values.astype(x_dtype) - const_op = HLOOp( - "constant", - [], - result_shape=values_np.shape, - result_dtype=x_dtype, - attributes={"value": values_np}, - ) - values_tensor = ctx.module.add_operation(const_op) - else: - raise ValueError( - "put_along_axis only supports scalar, TensorRef, or np.ndarray as values!" - ) - - if values_tensor.shape != idx_shape: - values_tensor = ctx.build_op( - "reshape", [values_tensor], idx_shape, values_tensor.dtype - ) - - # --- Compute flat 1D scatter indices --- - # Row-major strides of the effective shape. - ndim = len(effective_shape) - strides = [1] * ndim - for d in range(ndim - 2, -1, -1): - strides[d] = strides[d + 1] * effective_shape[d + 1] - - # Static offset: for each position in idx_shape, the flat contribution - # from all non-axis dimensions (known at trace time). - offset_np = np.zeros(idx_shape, dtype=np.int32) - for d in range(ndim): - if d == axis: - continue - coord = np.arange(idx_shape[d], dtype=np.int32) - bcast = [1] * len(idx_shape) - bcast[d] = idx_shape[d] - offset_np = offset_np + coord.reshape(bcast) * strides[d] - - offset_const = ctx.build_op( - "constant", [], idx_shape, np.dtype(np.int32), {"value": offset_np} - ) - - # flat_indices = indices * stride[axis] + offset - axis_stride_scalar = ctx.build_op( - "constant", - [], - (), - np.dtype(np.int32), - {"value": np.int32(strides[axis])}, - ) - axis_stride = ctx.build_op( - "broadcast", - [axis_stride_scalar], - idx_shape, - np.dtype(np.int32), - {"broadcast_dimensions": []}, - ) - scaled = ctx.build_op( - "multiply", - [indices_tensor, axis_stride], - idx_shape, - np.dtype(np.int32), - ) - flat_indices = ctx.build_op( - "add", - [scaled, offset_const], - idx_shape, - np.dtype(np.int32), - ) - - # --- Flatten and scatter --- - flat_size = int(np.prod(effective_shape)) - num_elements = int(np.prod(idx_shape)) - - x_flat = ctx.build_op("reshape", [x_copy], (flat_size,), x_dtype) - flat_indices_1d = ctx.build_op( - "reshape", - [flat_indices], - (num_elements,), - np.dtype(np.int32), - ) - flat_values_1d = ctx.build_op( - "reshape", - [values_tensor], - (num_elements,), - x_dtype, - ) - - scattered = ctx.build_op( - "scatter", - [x_flat, flat_indices_1d, flat_values_1d], - (flat_size,), - x_dtype, - { - "update_window_dims": [], - "inserted_window_dims": [0], - "scatter_dims_to_operand_dims": [0], - "index_vector_dim": 1, - "update_computation": "assign", - "indices_are_sorted": False, - "unique_indices": False, - }, - ) - - # Reshape back to original x shape - result_tensor = ctx.build_op("reshape", [scattered], x_shape, x_dtype) - return NKIPyTensorRef(result_tensor) - - -# ----------------------------------------------------------------------------- -# static_slice - Backend-agnostic static slicing operation -# ----------------------------------------------------------------------------- static_slice = Op("static_slice") - - -@static_slice.impl("hlo") -def _static_slice_hlo( - x, - start_indices: List[int], - limit_indices: List[int], - strides: List[int], - squeeze_dims: List[int], -): - """Static slicing using HLO slice operation. - - Args: - x: Input tensor (NKIPyTensorRef) - start_indices: Start index for each dimension - limit_indices: End index (exclusive) for each dimension - strides: Step size for each dimension - squeeze_dims: Dimensions to squeeze (remove) from output - - Returns: - Sliced tensor - """ - from nkipy.core.backend.hlo import get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - backend_tensor = x.backend_tensor - dtype = x.dtype - else: - backend_tensor = x - dtype = x.dtype - - # Calculate slice shape (before squeeze) - slice_shape = [] - for start, limit, stride in zip(start_indices, limit_indices, strides): - size = (limit - start + stride - 1) // stride - slice_shape.append(size) - - # Build HLO slice operation - result_tensor = ctx.build_op( - "slice", - [backend_tensor], - tuple(slice_shape), - dtype, - { - "start_indices": start_indices, - "limit_indices": limit_indices, - "strides": strides, - }, - ) - - # If we need to squeeze dimensions, reshape - if squeeze_dims: - output_shape = [s for i, s in enumerate(slice_shape) if i not in squeeze_dims] - final_shape = tuple(output_shape) if output_shape else () - result_tensor = ctx.build_op("reshape", [result_tensor], final_shape, dtype) - - return NKIPyTensorRef(result_tensor) - - -# ----------------------------------------------------------------------------- -# dynamic_update_slice - Backend-agnostic contiguous slice assignment -# ----------------------------------------------------------------------------- dynamic_update_slice = Op("dynamic_update_slice") - - -@dynamic_update_slice.impl("hlo") -def _dynamic_update_slice_hlo( - x, - value, - start_indices: List[int], - update_shape: Tuple[int, ...], -): - """Contiguous slice assignment using HLO dynamic-update-slice. - - Args: - x: Input tensor to update (NKIPyTensorRef) - value: Value tensor to insert (NKIPyTensorRef or scalar/array) - start_indices: Start index for each dimension - update_shape: Expected shape of the update region - - Returns: - Updated tensor (new tensor, not in-place) - """ - from nkipy.core.backend.hlo import get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x_tensor = x.backend_tensor - x_shape = x.shape - x_dtype = x.dtype - else: - x_tensor = x - x_shape = x.shape - x_dtype = x.dtype - - # Convert value to tensor if needed - if isinstance(value, NKIPyTensorRef): - value_tensor = value.backend_tensor - elif isinstance(value, (int, float)): - # Scalar value - create constant and broadcast - value_array = np.full(update_shape, value, dtype=x_dtype) - value_tensor = ctx.build_op( - "constant", [], tuple(update_shape), x_dtype, {"value": value_array} - ) - elif isinstance(value, np.ndarray): - value_tensor = ctx.build_op( - "constant", [], value.shape, value.dtype, {"value": value} - ) - else: - value_tensor = value - - # Reshape value if needed to match update_shape - if value_tensor.shape != update_shape: - value_tensor = ctx.build_op( - "reshape", [value_tensor], update_shape, value_tensor.dtype - ) - - # Create scalar constant tensors for start indices - start_index_tensors = [] - for start_idx in start_indices: - scalar_tensor = ctx.build_op("constant", [], (), np.int32, {"value": start_idx}) - start_index_tensors.append(scalar_tensor) - - # Build dynamic-update-slice operation - result_tensor = ctx.build_op( - "dynamic-update-slice", - [x_tensor, value_tensor] + start_index_tensors, - x_shape, - x_dtype, - {}, - ) - - return NKIPyTensorRef(result_tensor) - - -# ----------------------------------------------------------------------------- -# scatter_strided - Backend-agnostic strided slice assignment -# ----------------------------------------------------------------------------- scatter_strided = Op("scatter_strided") - - -@scatter_strided.impl("hlo") -def _scatter_strided_hlo( - x, - value, - scatter_indices_per_dim: List[List[int]], -): - """Strided slice assignment using HLO scatter. - - For a[::2, ::2] = b, scatter values to strided positions. - - Args: - x: Input tensor to update (NKIPyTensorRef) - value: Value tensor to scatter (NKIPyTensorRef or scalar/array) - scatter_indices_per_dim: List of index lists for each dimension - - Returns: - Updated tensor (new tensor, not in-place) - """ - from nkipy.core.backend.hlo import get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x_tensor = x.backend_tensor - x_shape = x.shape - x_dtype = x.dtype - else: - x_tensor = x - x_shape = x.shape - x_dtype = x.dtype - - # Calculate expected value shape from scatter indices - value_shape = tuple(len(indices) for indices in scatter_indices_per_dim) - - # Convert value to tensor if needed - if isinstance(value, NKIPyTensorRef): - value_tensor = value.backend_tensor - elif isinstance(value, (int, float)): - value_array = np.full(value_shape, value, dtype=x_dtype) - value_tensor = ctx.build_op( - "constant", [], value_shape, x_dtype, {"value": value_array} - ) - elif isinstance(value, np.ndarray): - value_tensor = ctx.build_op( - "constant", [], value.shape, value.dtype, {"value": value} - ) - else: - value_tensor = value - - # Create meshgrid of all scatter positions - all_positions = list(itertools.product(*scatter_indices_per_dim)) - scatter_indices_array = np.array(all_positions, dtype=np.int32) - - # Create HLO constant for scatter indices - indices_tensor = ctx.build_op( - "constant", - [], - scatter_indices_array.shape, - np.dtype(np.int32), - {"value": scatter_indices_array}, - ) - - # Flatten the value tensor to match the number of scatter positions - flat_value_shape = (scatter_indices_array.shape[0],) - flat_value = ctx.build_op( - "reshape", [value_tensor], flat_value_shape, value_tensor.dtype - ) - - # Configure scatter parameters - update_window_dims = [] - inserted_window_dims = list(range(len(x_shape))) - scatter_dims_to_operand_dims = list(range(len(x_shape))) - index_vector_dim = 1 - - # Build scatter operation - result_tensor = ctx.build_op( - "scatter", - [x_tensor, indices_tensor, flat_value], - x_shape, - x_dtype, - { - "update_window_dims": update_window_dims, - "inserted_window_dims": inserted_window_dims, - "scatter_dims_to_operand_dims": scatter_dims_to_operand_dims, - "index_vector_dim": index_vector_dim, - "update_computation": "assign", - "indices_are_sorted": False, - "unique_indices": True, - }, - ) - - return NKIPyTensorRef(result_tensor) diff --git a/nkipy/src/nkipy/core/ops/linalg.py b/nkipy/src/nkipy/core/ops/linalg.py index 89866d9..492ca34 100644 --- a/nkipy/src/nkipy/core/ops/linalg.py +++ b/nkipy/src/nkipy/core/ops/linalg.py @@ -2,209 +2,24 @@ # SPDX-License-Identifier: Apache-2.0 """Linear algebra operations: matmul, dot""" -import numpy as np - from nkipy.core.ops._registry import Op # ----------------------------------------------------------------------------- -# matmul +# Primitive linalg ops # ----------------------------------------------------------------------------- matmul = Op("matmul") - - -@matmul.impl("hlo") -def _matmul_hlo(x, y, out=None, dtype=None): - """Matrix multiplication (HLO). - - Supports batched matrix multiplication with broadcasting. - Follows numpy semantics for 1D inputs: a 1D left operand is promoted to - 2D by prepending a 1, a 1D right operand by appending a 1, and the extra - dimension is removed from the result after the dot. - """ - from nkipy.core.backend.hlo import ( - broadcast_to_shape_hlo, - find_common_type_hlo, - get_hlo_context, - ) - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - result_dtype = find_common_type_hlo(x, y) - - if isinstance(x, NKIPyTensorRef): - x = x.backend_tensor - if isinstance(y, NKIPyTensorRef): - y = y.backend_tensor - - # Matmul requires at least 1D arrays - assert len(x.shape) >= 1 and len(y.shape) >= 1, "matmul requires at least 1D arrays" - - # Handle 1D inputs by promoting to 2D (numpy matmul semantics). - # The added dimension is stripped from the result after the dot. - squeeze_lhs = False - squeeze_rhs = False - - if len(x.shape) == 1 and len(y.shape) == 1: - # Vector dot product: contract dimension 0 of both - assert x.shape[0] == y.shape[0], "Incompatible shapes for dot product" - result_shape = () - lhs_contracting_dims = [0] - rhs_contracting_dims = [0] - lhs_batch_dims = [] - rhs_batch_dims = [] - else: - # Promote 1D operands to 2D so the general path handles them. - if len(x.shape) == 1: - # (K,) -> (1, K) - x = ctx.build_op("reshape", [x], (1, x.shape[0]), x.dtype) - squeeze_lhs = True - - if len(y.shape) == 1: - # (K,) -> (K, 1) - y = ctx.build_op("reshape", [y], (y.shape[0], 1), y.dtype) - squeeze_rhs = True - - # General matrix multiplication (2D or batched) - assert x.shape[-1] == y.shape[-2], "Incompatible shapes for matmul" - - # Broadcast batch dimensions if needed - x_batch_shape = x.shape[:-2] - y_batch_shape = y.shape[:-2] - batch_shape = tuple(np.broadcast_shapes(x_batch_shape, y_batch_shape)) - result_shape = batch_shape + (x.shape[-2], y.shape[-1]) - - # If batch dimensions don't match, broadcast the operands first - target_x_shape = batch_shape + tuple(x.shape[-2:]) - target_y_shape = batch_shape + tuple(y.shape[-2:]) - - if x.shape != target_x_shape: - x = broadcast_to_shape_hlo(ctx, x, target_x_shape) - - if y.shape != target_y_shape: - y = broadcast_to_shape_hlo(ctx, y, target_y_shape) - - # Contracting dimensions - lhs_contracting_dims = [len(target_x_shape) - 1] - rhs_contracting_dims = [len(target_y_shape) - 2] - - # Batch dimensions - lhs_batch_dims = list(range(len(batch_shape))) - rhs_batch_dims = list(range(len(batch_shape))) - - # Build dot operation with dimension numbers - result_tensor = ctx.build_op( - "dot", - [x, y], - result_shape, - result_dtype, - { - "lhs_contracting_dimensions": lhs_contracting_dims, - "rhs_contracting_dimensions": rhs_contracting_dims, - "lhs_batch_dimensions": lhs_batch_dims, - "rhs_batch_dimensions": rhs_batch_dims, - }, - ) - - # Strip the dimensions that were added for 1D promotion. - if squeeze_lhs or squeeze_rhs: - final_shape = list(result_shape) - # squeeze_lhs removes the second-to-last dim (the prepended 1) - # squeeze_rhs removes the last dim (the appended 1) - # When both are true the result is already scalar from the 1D x 1D - # path above, so this branch won't fire for that case. - if squeeze_lhs: - final_shape.pop(-2) - if squeeze_rhs: - final_shape.pop(-1) - final_shape = tuple(final_shape) - result_tensor = ctx.build_op( - "reshape", [result_tensor], final_shape, result_dtype - ) - - return NKIPyTensorRef(result_tensor) - - -# ----------------------------------------------------------------------------- -# dot -# ----------------------------------------------------------------------------- dot = Op("dot") - -@dot.impl("hlo") -def _dot_hlo(x, y, out=None): - """Dot product (HLO). - - Follows numpy.dot semantics: - - 1D x 1D: scalar inner product. - - 2D x 2D: matrix multiplication. - - N-D x 1D: sum product over the last axis of *x* and *y*. - - N-D x M-D (M>=2): sum product over the last axis of *x* and the - second-to-last axis of *y*. Non-contracted dimensions are - outer-producted (NOT broadcast like matmul). - """ - from nkipy.core.backend.hlo import ( - find_common_type_hlo, - get_hlo_context, - ) - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - result_dtype = find_common_type_hlo(x, y) - - if isinstance(x, NKIPyTensorRef): - x = x.backend_tensor - if isinstance(y, NKIPyTensorRef): - y = y.backend_tensor - - assert len(x.shape) >= 1 and len(y.shape) >= 1, "dot requires at least 1D arrays" - - # Contracting dimensions: last of x, second-to-last of y (or 0 if y is 1D) - lhs_contracting_dims = [len(x.shape) - 1] - rhs_contracting_dims = [max(0, len(y.shape) - 2)] - - assert x.shape[lhs_contracting_dims[0]] == y.shape[rhs_contracting_dims[0]], ( - f"shapes {x.shape} and {y.shape} not aligned" - ) - - # No batch dimensions for dot – remaining dims are outer-producted. - lhs_batch_dims = [] - rhs_batch_dims = [] - - # Result shape: non-contracted dims from x, then non-contracted dims from y - result_shape = tuple( - s for i, s in enumerate(x.shape) if i not in lhs_contracting_dims - ) + tuple(s for i, s in enumerate(y.shape) if i not in rhs_contracting_dims) - - result_tensor = ctx.build_op( - "dot", - [x, y], - result_shape, - result_dtype, - { - "lhs_contracting_dimensions": lhs_contracting_dims, - "rhs_contracting_dimensions": rhs_contracting_dims, - "lhs_batch_dimensions": lhs_batch_dims, - "rhs_batch_dimensions": rhs_batch_dims, - }, - ) - - return NKIPyTensorRef(result_tensor) - - # ----------------------------------------------------------------------------- -# norm - L2/Frobenius norm +# Composed linalg ops # ----------------------------------------------------------------------------- -norm = Op("norm") +norm = Op("norm") -@norm.impl("hlo") -def _norm_hlo(x, ord=None, axis=None, keepdims=False): - """L2/Frobenius norm: sqrt(sum(x*x, axis)). - Only supports ord=None or ord='fro'. - """ +@norm.composed_impl +def _norm(x, ord=None, axis=None, keepdims=False): + """L2/Frobenius norm: sqrt(sum(x*x, axis)).""" from nkipy.core.ops.binary import multiply from nkipy.core.ops.reduce import sum from nkipy.core.ops.unary import sqrt @@ -225,14 +40,11 @@ def _norm_hlo(x, ord=None, axis=None, keepdims=False): return sqrt(sum_squared) -# ----------------------------------------------------------------------------- -# outer - outer product of two vectors -# ----------------------------------------------------------------------------- outer = Op("outer") -@outer.impl("hlo") -def _outer_hlo(a, b, out=None): +@outer.composed_impl +def _outer(a, b, out=None): """Outer product: reshape a to (n, 1), b to (1, m), multiply.""" from nkipy.core.ops.binary import multiply from nkipy.core.ops.transform import reshape @@ -242,75 +54,4 @@ def _outer_hlo(a, b, out=None): return multiply(a_flat, b_flat) -# ----------------------------------------------------------------------------- -# trace - sum of diagonal elements -# ----------------------------------------------------------------------------- trace = Op("trace") - - -@trace.impl("hlo") -def _trace_hlo(a, offset=0, axis1=0, axis2=1, dtype=None): - """Sum of diagonal elements using iota indices. - - Create iota for rows and cols, compare for equality (diagonal mask), - multiply with input, sum. - """ - from nkipy.core.backend.hlo import get_hlo_context - from nkipy.core.ops.binary import equal, multiply - from nkipy.core.ops.indexing import where - from nkipy.core.ops.reduce import sum - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(a, NKIPyTensorRef): - a_bt = a.backend_tensor - else: - a_bt = a - - # For 2D: create iota for rows and cols, mask diagonal, sum - shape = a_bt.shape - ndim = len(shape) - - # Normalize negative axes - if axis1 < 0: - axis1 += ndim - if axis2 < 0: - axis2 += ndim - - # Create iota along axis1 (row indices) - row_iota = ctx.build_op( - "iota", [], shape, np.dtype(np.int32), {"iota_dimension": axis1} - ) - row_ref = NKIPyTensorRef(row_iota) - - # Create iota along axis2 (col indices) and add offset - col_iota = ctx.build_op( - "iota", [], shape, np.dtype(np.int32), {"iota_dimension": axis2} - ) - col_ref = NKIPyTensorRef(col_iota) - - if offset != 0: - from nkipy.core.ops.binary import subtract - - col_ref = subtract(col_ref, offset) - - # Diagonal mask: row == col (with offset) - diag_mask = equal(row_ref, col_ref) - - # Apply mask: where diagonal, use a; else 0 - masked = where(diag_mask, a, 0.0) - - # Sum along both matrix axes - # Sum axis2 first (higher index), then axis1 - axes_to_reduce = sorted([axis1, axis2], reverse=True) - result = masked - for ax in axes_to_reduce: - result = sum(result, axis=ax) - - if dtype is not None: - from nkipy.core.ops.transform import astype - - result = astype(result, np.dtype(dtype)) - - return result diff --git a/nkipy/src/nkipy/core/ops/nn.py b/nkipy/src/nkipy/core/ops/nn.py index 69f1ef6..a79a1d1 100644 --- a/nkipy/src/nkipy/core/ops/nn.py +++ b/nkipy/src/nkipy/core/ops/nn.py @@ -11,150 +11,36 @@ # ----------------------------------------------------------------------------- softmax = Op("softmax") -# HLO implementation for softmax not implemented - # ----------------------------------------------------------------------------- # topk # ----------------------------------------------------------------------------- topk = Op("topk") +# ----------------------------------------------------------------------------- +# rms_norm +# ----------------------------------------------------------------------------- +rms_norm = Op("rms_norm") + @topk.impl("cpu") def _topk_cpu(x, k, axis=0, is_ascend=False, out=None, dtype=None): - """Top-k operation (CPU). - - Args: - x: Input array - k: Number of top elements to extract - axis: Axis along which to find top k elements - is_ascend: If True, find k smallest elements; if False, find k largest - out: Unused (for API compatibility) - dtype: Unused (for API compatibility) - - Returns: - Tuple of (values, indices) arrays - """ - # Normalize negative axis + """Top-k operation (CPU).""" if axis < 0: axis = x.ndim + axis if is_ascend: - # Find k smallest - use argpartition for efficiency - # Use k-1 as partition index to avoid out-of-bounds when k equals array size indices = np.argpartition(x, k - 1, axis=axis) indices = np.take(indices, range(k), axis=axis) - # Sort the k elements values = np.take_along_axis(x, indices, axis=axis) sort_indices = np.argsort(values, axis=axis) indices = np.take_along_axis(indices, sort_indices, axis=axis) values = np.take_along_axis(values, sort_indices, axis=axis) else: - # Find k largest - negate, partition, negate back - # Use k-1 as partition index to avoid out-of-bounds when k equals array size indices = np.argpartition(-x, k - 1, axis=axis) indices = np.take(indices, range(k), axis=axis) - # Sort the k elements in descending order values = np.take_along_axis(x, indices, axis=axis) sort_indices = np.argsort(-values, axis=axis) indices = np.take_along_axis(indices, sort_indices, axis=axis) values = np.take_along_axis(values, sort_indices, axis=axis) return values, indices.astype(np.uint32) - - -@topk.impl("hlo") -def _topk_hlo(x, k, axis=0, is_ascend=False, out=None, dtype=None): - """Top-k operation for HLO backend. - - Args: - x: Input tensor - k: Number of top elements to extract - axis: Axis along which to find top k elements - is_ascend: If True, find k smallest elements; if False, find k largest - out: Unused (for API compatibility) - dtype: Unused (for API compatibility) - - Returns: - Tuple of (values, indices) tensors - """ - from nkipy.core.backend.hlo import get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - if axis != -1 and axis != x.ndim - 1: - raise NotImplementedError("the custom TopK op only supports last axis") - - ctx = get_hlo_context() - - # Convert NKIPyTensorRef to HLOTensor if needed - if isinstance(x, NKIPyTensorRef): - x = x.backend_tensor - - # Normalize negative axis - if axis < 0: - axis = len(x.shape) + axis - - # Validate k - assert x.shape[axis] >= k, ( - f"k={k} must be <= size of axis {axis} which is {x.shape[axis]}" - ) - - output_shape = list(x.shape) - output_shape[axis] = k - output_shape = tuple(output_shape) - - input_for_topk = x - # Build TopK operation - # Note: HLO TopK always returns largest elements, so for smallest we need to negate - if is_ascend: - # For ascending (smallest k), negate the input - input_for_topk = ctx.build_op( - "negate", [input_for_topk], input_for_topk.shape, input_for_topk.dtype - ) - - # Create the TopK operation - topk_output_shape = list(input_for_topk.shape) - topk_output_shape[-1] = k - topk_output_shape = tuple(topk_output_shape) - - # Build TopK - it returns a tuple of (values, indices) - topk_tuple = ctx.build_op( - "topk", - [input_for_topk], - topk_output_shape, - x.dtype, - {"k": k, "largest": True, "is_tuple": True}, - ) - - # Extract values (tuple element 0) - values_tensor = ctx.build_op( - "get-tuple-element", - [topk_tuple], - topk_output_shape, - x.dtype, - {"tuple_index": 0}, - ) - - # Extract indices (tuple element 1) - indices_tensor = ctx.build_op( - "get-tuple-element", - [topk_tuple], - topk_output_shape, - np.dtype(np.uint32), - {"tuple_index": 1}, - ) - - # If we negated for ascending order, negate the values back - if is_ascend: - values_tensor = ctx.build_op( - "negate", [values_tensor], topk_output_shape, x.dtype - ) - - return NKIPyTensorRef(values_tensor), NKIPyTensorRef(indices_tensor) - - -# ----------------------------------------------------------------------------- -# rms_norm -# ----------------------------------------------------------------------------- -rms_norm = Op("rms_norm") - -# HLO implementation for rmsnorm not implemented diff --git a/nkipy/src/nkipy/core/ops/reduce.py b/nkipy/src/nkipy/core/ops/reduce.py index cb1dc74..ea669e9 100644 --- a/nkipy/src/nkipy/core/ops/reduce.py +++ b/nkipy/src/nkipy/core/ops/reduce.py @@ -6,91 +6,6 @@ from nkipy.core.ops._registry import Op -# ============================================================================= -# HLO Implementation -# ============================================================================= - - -def _build_reduction_hlo( - x, np_op, axis=None, out=None, dtype=None, keepdims=False, initial=None -): - """Build a reduction HLO operation.""" - from nkipy.core.backend.hlo import as_hlo_tensor, get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x = x.backend_tensor - - # Map numpy reduction ops to HLO reduce computation names - reduce_op_map = { - np.sum: "add", - np.max: "maximum", - np.min: "minimum", - np.prod: "multiply", - } - - if np_op not in reduce_op_map: - raise NotImplementedError( - f"Reduction operation {np_op} not yet supported in HLO tracing" - ) - - hlo_op = reduce_op_map[np_op] - - # Handle axis parameter - normalize to tuple - if axis is None: - dimensions_to_reduce = tuple(range(len(x.shape))) - elif isinstance(axis, int): - dim = axis if axis >= 0 else len(x.shape) + axis - dimensions_to_reduce = (dim,) - elif isinstance(axis, (list, tuple)): - dimensions_to_reduce = tuple( - ax if ax >= 0 else len(x.shape) + ax for ax in axis - ) - else: - dimensions_to_reduce = (axis,) - - # Calculate output shape - HLO reduce always removes dimensions - reduced_shape = tuple( - s for i, s in enumerate(x.shape) if i not in dimensions_to_reduce - ) - - # Create init value based on operation as a constant tensor - init_values = { - "add": 0.0, - "maximum": float("-inf"), - "minimum": float("inf"), - "multiply": 1.0, - } - init_value = init_values[hlo_op] - - # Create init value as a scalar constant tensor - init_tensor = as_hlo_tensor(ctx, init_value, x.dtype) - - # Build the reduce operation - result_tensor = ctx.build_op( - "reduce", - [x, init_tensor], - reduced_shape, - x.dtype, - { - "dimensions": list(dimensions_to_reduce), - "computation": hlo_op, - }, - ) - - # If keepdims is True, reshape to add back the reduced dimensions as size 1 - if keepdims: - keepdims_shape = tuple( - 1 if i in dimensions_to_reduce else s for i, s in enumerate(x.shape) - ) - result_tensor = ctx.build_op( - "reshape", [result_tensor], keepdims_shape, x.dtype - ) - - return NKIPyTensorRef(result_tensor) - def _calculate_reduction_count(x_shape, axis): """Calculate the number of elements being reduced.""" @@ -108,230 +23,26 @@ def _calculate_reduction_count(x_shape, axis): # ----------------------------------------------------------------------------- -# Factory function for reduction ops -# ----------------------------------------------------------------------------- -def _make_reduction_op(name: str, np_op) -> Op: - """Create a reduction Op with IR and HLO implementations.""" - op = Op(name) - - @op.impl("hlo") - def _impl_hlo(x, axis=None, out=None, dtype=None, keepdims=False, initial=None): - return _build_reduction_hlo( - x, - np_op, - axis=axis, - out=out, - dtype=dtype, - keepdims=keepdims, - initial=initial, - ) - - return op - - -# ----------------------------------------------------------------------------- -# Reduction operations -# ----------------------------------------------------------------------------- -sum = _make_reduction_op("sum", np.sum) -max = _make_reduction_op("max", np.max) -min = _make_reduction_op("min", np.min) -prod = _make_reduction_op("prod", np.prod) - -# mean: sum / count - simplified using ops -mean = Op("mean") - - -@mean.impl("hlo") -def _mean_hlo(x, axis=None, out=None, dtype=None, keepdims=False, initial=None): - """Mean operation: sum(x) / count.""" - from nkipy.core.ops.binary import divide - - if dtype is not None: - from nkipy.core.ops.transform import astype - - x = astype(x, np.dtype(dtype)) - - sum_result = sum(x, axis=axis, keepdims=keepdims) - count = _calculate_reduction_count(x.shape, axis) - - return divide(sum_result, float(count)) - - -# ----------------------------------------------------------------------------- -# any - special reduction that checks if any element is True -# ----------------------------------------------------------------------------- -any = Op("any") - - -@any.impl("hlo") -def _any_hlo(x, axis=None, out=None, dtype=None, keepdims=False): - """Check if any element is True along the given axis. - - any(x) = (sum(x != 0) != 0) - """ - from nkipy.core.ops.binary import not_equal - from nkipy.core.ops.transform import astype - - non_zero = not_equal(x, 0) - non_zero_i32 = astype(non_zero, np.dtype(np.int32)) - summed = sum(non_zero_i32, axis=axis, keepdims=keepdims) - - return not_equal(summed, 0) - - -# ----------------------------------------------------------------------------- -# var - variance: mean((x - mean(x))^2) -# ----------------------------------------------------------------------------- -var = Op("var") - - -@var.impl("hlo") -def _var_hlo(x, axis=None, out=None, dtype=None, keepdims=False, ddof=0): - """Variance operation: sum((x - mean(x))^2) / (N - ddof).""" - from nkipy.core.ops.binary import divide, multiply, subtract - - if dtype is not None: - from nkipy.core.ops.transform import astype - - x = astype(x, np.dtype(dtype)) - - # Compute mean with keepdims=True so it broadcasts back against x - mean_x = mean(x, axis=axis, keepdims=True) - - centered = subtract(x, mean_x) - squared = multiply(centered, centered) - sum_squared = sum(squared, axis=axis, keepdims=keepdims) - count = _calculate_reduction_count(x.shape, axis) - - denom = count - ddof if count - ddof > 0 else 0 - return divide(sum_squared, float(denom)) - - -# ----------------------------------------------------------------------------- -# std - standard deviation: sqrt(var(x)) -# ----------------------------------------------------------------------------- -std = Op("std") - - -@std.impl("hlo") -def _std_hlo(x, axis=None, out=None, dtype=None, keepdims=False, ddof=0): - """Standard deviation: sqrt(var(x, axis, ddof)).""" - from nkipy.core.ops.unary import sqrt - - return sqrt(var(x, axis=axis, dtype=dtype, keepdims=keepdims, ddof=ddof)) - - -# ----------------------------------------------------------------------------- -# argmax - index of maximum value along an axis +# Primitive reduction operations # ----------------------------------------------------------------------------- +sum = Op("sum") +max = Op("max") +min = Op("min") +prod = Op("prod") argmax = Op("argmax") - - -@argmax.impl("hlo") -def _argmax_hlo(x, axis=None, out=None, keepdims=False): - """Argmax: find index of maximum value along axis. - - Strategy: max_val → mask where equal → create iota indices → where(mask, iota, large) → min - """ - from nkipy.core.backend.hlo import get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x_ref = x - x_bt = x.backend_tensor - else: - x_ref = NKIPyTensorRef(x) - x_bt = x - - # Save original shape/axis for keepdims support - original_axis = axis - x_ref_original_shape = x_ref.shape - - # If axis is None, flatten first - if axis is None: - from nkipy.core.ops.transform import reshape - - total = int(np.prod(x_ref.shape)) - x_ref = reshape(x_ref, (total,)) - x_bt = x_ref.backend_tensor - axis = 0 - - # Normalize negative axis - ndim = len(x_bt.shape) - if axis < 0: - axis = ndim + axis - - # Step 1: Get max value along axis (keepdims for broadcasting) - max_val = max(x_ref, axis=axis, keepdims=True) - - # Step 2: Create boolean mask where x == max_val - from nkipy.core.ops.binary import equal - - mask = equal(x_ref, max_val) - - # Step 3: Create iota indices along the target axis as float - # (using float avoids int32 overflow with min's inf init value) - iota_tensor = ctx.build_op( - "iota", - [], - x_bt.shape, - np.dtype(np.float32), - {"iota_dimension": axis}, - ) - iota_ref = NKIPyTensorRef(iota_tensor) - - # Step 4: Where mask is true, use iota index; else use large float value - large_val = float(x_bt.shape[axis] + 1) - from nkipy.core.ops.indexing import where - - masked_indices = where(mask, iota_ref, large_val) - - # Step 5: Min along axis to find the first occurrence - result_float = min(masked_indices, axis=axis) - - # NKI hardware uses int32 for indices; NumPy returns int64 - from nkipy.core.ops.transform import astype - - result = astype(result_float, np.dtype(np.int32)) - - if keepdims: - from nkipy.core.ops.transform import reshape - - if original_axis is not None: - keepdims_shape = list(x_ref_original_shape) - keepdims_shape[original_axis] = 1 - result = reshape(result, tuple(keepdims_shape)) - else: - # axis=None flattened → restore all dims as size-1 - keepdims_shape = tuple(1 for _ in x_ref_original_shape) - result = reshape(result, keepdims_shape) - - return result - - -# ----------------------------------------------------------------------------- -# cumsum - cumulative sum along an axis -# ----------------------------------------------------------------------------- +argmin = Op("argmin") cumsum = Op("cumsum") -@cumsum.impl("hlo") -def _cumsum_hlo(x, axis=None, dtype=None): - """Cumulative sum via triangular matrix multiplication. - - For axis of size N: create NxN upper-triangular ones matrix, matmul. - Multi-dim: transpose target axis to last, reshape to 2D, apply, reshape+transpose back. - """ +@cumsum.composed_impl +def _cumsum(x, axis=None, dtype=None): from nkipy.core.ops.creation import constant from nkipy.core.ops.linalg import matmul - from nkipy.core.ops.transform import reshape, transpose + from nkipy.core.ops.transform import astype, reshape, transpose x_shape = x.shape ndim = len(x_shape) - # Flatten when axis is None (matches numpy behavior) if axis is None: total = int(np.prod(x_shape)) x = reshape(x, (total,)) @@ -343,28 +54,19 @@ def _cumsum_hlo(x, axis=None, dtype=None): N = x_shape[axis] - # Create upper-triangular ones matrix (N, N) - # M[j,i] = 1 if j <= i, so x @ M gives cumsum - # Always use float32 for numerical stability; dtype conversion happens after. - # FIXME: when input is integer and dtype is not specified, the result will be - # float32 instead of preserving the input dtype (NumPy preserves it). tri = np.triu(np.ones((N, N), dtype=np.float32)) tri_const = constant(tri) - # If axis is the last dimension and tensor is 2D, simple case if ndim == 1: - # (1, N) @ (N, N) -> (1, N) -> (N,) x_2d = reshape(x, (1, N)) result_2d = matmul(x_2d, tri_const) result = reshape(result_2d, (N,)) elif axis == ndim - 1: - # Reshape to (..., N), matmul with (N, N), reshape back batch_size = int(np.prod(x_shape[:-1])) x_2d = reshape(x, (batch_size, N)) result_2d = matmul(x_2d, tri_const) result = reshape(result_2d, x_shape) else: - # Transpose target axis to last, apply, transpose back perm = list(range(ndim)) perm[axis], perm[-1] = perm[-1], perm[axis] x_t = transpose(x, axes=perm) @@ -376,108 +78,86 @@ def _cumsum_hlo(x, axis=None, dtype=None): result_t = reshape(result_2d, x_t_shape) result = transpose(result_t, axes=perm) - # Cast to requested dtype if needed if dtype is not None and result.dtype != np.dtype(dtype): - from nkipy.core.ops.transform import astype - result = astype(result, np.dtype(dtype)) return result - # ----------------------------------------------------------------------------- -# argmin - index of minimum value along an axis +# Composed reduction operations # ----------------------------------------------------------------------------- -argmin = Op("argmin") +mean = Op("mean") -@argmin.impl("hlo") -def _argmin_hlo(x, axis=None, out=None, keepdims=False): - """Argmin: find index of minimum value along axis. - Strategy: min_val → mask where equal → create iota indices → where(mask, iota, large) → min - """ - from nkipy.core.backend.hlo import get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef +@mean.composed_impl +def _mean(x, axis=None, out=None, dtype=None, keepdims=False, initial=None): + from nkipy.core.ops.binary import divide - ctx = get_hlo_context() + if dtype is not None: + from nkipy.core.ops.transform import astype - if isinstance(x, NKIPyTensorRef): - x_ref = x - x_bt = x.backend_tensor - else: - x_ref = NKIPyTensorRef(x) - x_bt = x + x = astype(x, np.dtype(dtype)) - original_axis = axis - x_ref_original_shape = x_ref.shape + sum_result = sum(x, axis=axis, keepdims=keepdims) + count = _calculate_reduction_count(x.shape, axis) - if axis is None: - from nkipy.core.ops.transform import reshape + return divide(sum_result, float(count)) - total = int(np.prod(x_ref.shape)) - x_ref = reshape(x_ref, (total,)) - x_bt = x_ref.backend_tensor - axis = 0 - ndim = len(x_bt.shape) - if axis < 0: - axis = ndim + axis +any = Op("any") - # Step 1: Get min value along axis (keepdims for broadcasting) - min_val = min(x_ref, axis=axis, keepdims=True) - # Step 2: Create boolean mask where x == min_val - from nkipy.core.ops.binary import equal +@any.composed_impl +def _any(x, axis=None, out=None, dtype=None, keepdims=False): + from nkipy.core.ops.binary import not_equal + from nkipy.core.ops.transform import astype - mask = equal(x_ref, min_val) + non_zero = not_equal(x, 0) + non_zero_i32 = astype(non_zero, np.dtype(np.int32)) + summed = sum(non_zero_i32, axis=axis, keepdims=keepdims) - # Step 3: Create iota indices along the target axis as float - iota_tensor = ctx.build_op( - "iota", - [], - x_bt.shape, - np.dtype(np.float32), - {"iota_dimension": axis}, - ) - iota_ref = NKIPyTensorRef(iota_tensor) + return not_equal(summed, 0) - # Step 4: Where mask is true, use iota index; else use large float value - large_val = float(x_bt.shape[axis] + 1) - from nkipy.core.ops.indexing import where - masked_indices = where(mask, iota_ref, large_val) +var = Op("var") - # Step 5: Min along axis to find the first occurrence - result_float = min(masked_indices, axis=axis) - from nkipy.core.ops.transform import astype +@var.composed_impl +def _var(x, axis=None, out=None, dtype=None, keepdims=False, ddof=0): + from nkipy.core.ops.binary import divide, multiply, subtract - result = astype(result_float, np.dtype(np.int32)) + if dtype is not None: + from nkipy.core.ops.transform import astype - if keepdims: - from nkipy.core.ops.transform import reshape + x = astype(x, np.dtype(dtype)) - if original_axis is not None: - keepdims_shape = list(x_ref_original_shape) - keepdims_shape[original_axis] = 1 - result = reshape(result, tuple(keepdims_shape)) - else: - keepdims_shape = tuple(1 for _ in x_ref_original_shape) - result = reshape(result, keepdims_shape) + mean_x = mean(x, axis=axis, keepdims=True) - return result + centered = subtract(x, mean_x) + squared = multiply(centered, centered) + sum_squared = sum(squared, axis=axis, keepdims=keepdims) + count = _calculate_reduction_count(x.shape, axis) + + denom = count - ddof if count - ddof > 0 else 0 + return divide(sum_squared, float(denom)) + + +std = Op("std") + + +@std.composed_impl +def _std(x, axis=None, out=None, dtype=None, keepdims=False, ddof=0): + from nkipy.core.ops.unary import sqrt + + return sqrt(var(x, axis=axis, dtype=dtype, keepdims=keepdims, ddof=ddof)) -# ----------------------------------------------------------------------------- -# count_nonzero - count non-zero elements along an axis -# ----------------------------------------------------------------------------- count_nonzero = Op("count_nonzero") -@count_nonzero.impl("hlo") -def _count_nonzero_hlo(x, axis=None, keepdims=False): - """Count non-zero elements: sum(x != 0).""" +@count_nonzero.composed_impl +def _count_nonzero(x, axis=None, keepdims=False): from nkipy.core.ops.binary import not_equal from nkipy.core.ops.transform import astype diff --git a/nkipy/src/nkipy/core/ops/transform.py b/nkipy/src/nkipy/core/ops/transform.py index 0ffa262..353d8d6 100644 --- a/nkipy/src/nkipy/core/ops/transform.py +++ b/nkipy/src/nkipy/core/ops/transform.py @@ -4,873 +4,30 @@ reshape, transpose, expand_dims, concatenate, split, copy, repeat """ -import numpy as np - from nkipy.core.ops._registry import Op # ----------------------------------------------------------------------------- -# reshape +# Primitive transform ops # ----------------------------------------------------------------------------- reshape = Op("reshape") - - -@reshape.impl("hlo") -def _reshape_hlo(x, newshape): - """Reshape tensor to new shape (HLO).""" - from nkipy.core.backend.hlo import get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x = x.backend_tensor - - # Normalize shape (handle -1) - if isinstance(newshape, int): - newshape = (newshape,) - - # Handle -1 in newshape - if -1 in newshape: - total_size = int(np.prod(x.shape)) - known_size = int(np.prod([d for d in newshape if d != -1])) - assert known_size > 0, "Cannot reshape to a size of 0" - assert total_size % known_size == 0, ( - f"Cannot reshape array of size {total_size} into shape {newshape}" - ) - newshape = tuple(total_size // known_size if d == -1 else d for d in newshape) - - # Verify total size matches - if np.prod(x.shape) != np.prod(newshape): - raise ValueError( - f"Cannot reshape array of size {np.prod(x.shape)} into shape {newshape}" - ) - - result_tensor = ctx.build_op("reshape", [x], newshape, x.dtype) - return NKIPyTensorRef(result_tensor) - - -# ----------------------------------------------------------------------------- -# transpose -# ----------------------------------------------------------------------------- transpose = Op("transpose") - - -@transpose.impl("hlo") -def _transpose_hlo(x, axes=None, out=None, dtype=None): - """Transpose tensor (HLO).""" - from nkipy.core.backend.hlo import get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x = x.backend_tensor - - # Handle default axes (reverse all dimensions) - if axes is None: - axes = list(range(len(x.shape)))[::-1] - - # Calculate output shape - result_shape = tuple(x.shape[i] for i in axes) - - result_tensor = ctx.build_op( - "transpose", [x], result_shape, x.dtype, {"permutation": axes} - ) - return NKIPyTensorRef(result_tensor) - - -# ----------------------------------------------------------------------------- -# swapaxes -# ----------------------------------------------------------------------------- -swapaxes = Op("swapaxes") - - -@swapaxes.impl("hlo") -def _swapaxes_hlo(x, axis1, axis2): - """Swap two axes of a tensor by delegating to transpose.""" - ndim = len(x.shape) - if axis1 < 0: - axis1 += ndim - if axis2 < 0: - axis2 += ndim - if not (0 <= axis1 < ndim) or not (0 <= axis2 < ndim): - raise np.AxisError(f"axis is out of bounds for array of dimension {ndim}") - axes = list(range(ndim)) - axes[axis1], axes[axis2] = axes[axis2], axes[axis1] - return transpose(x, axes=axes) - - -# ----------------------------------------------------------------------------- -# stack -# ----------------------------------------------------------------------------- -stack = Op("stack") - - -@stack.impl("hlo") -def _stack_hlo(arrays, axis=0, out=None, dtype=None): - """Stack tensors along a new axis using expand_dims + concatenate.""" - expanded = [expand_dims(a, axis=axis) for a in arrays] - return concatenate(expanded, axis=axis) - - -# ----------------------------------------------------------------------------- -# expand_dims -# ----------------------------------------------------------------------------- expand_dims = Op("expand_dims") - - -@expand_dims.impl("hlo") -def _expand_dims_hlo(x, axis): - """Expand dimensions of tensor (HLO).""" - from nkipy.core.backend.hlo import get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x = x.backend_tensor - - rank = len(x.shape) - - # Handle both single axis and list of axes - if isinstance(axis, (list, tuple)): - # Final rank after all dimensions are added - final_rank = rank + len(axis) - - # Normalize negative axes relative to final rank - axes = [] - for ax in axis: - if ax < 0: - ax = final_rank + ax - if ax < 0 or ax > final_rank - 1: - raise ValueError( - f"axis {ax} is out of bounds for array of dimension {final_rank}" - ) - axes.append(ax) - - # Check for duplicate axes - if len(axes) != len(set(axes)): - raise ValueError("repeated axis in expand_dims") - - axes = sorted(axes) - - new_shape = list(x.shape) - for ax in axes: - new_shape.insert(ax, 1) - new_shape = tuple(new_shape) - else: - if axis < 0: - axis = rank + axis + 1 - - if axis < 0 or axis > rank: - raise ValueError( - f"axis {axis} is out of bounds for array of dimension {rank}" - ) - - new_shape = list(x.shape) - new_shape.insert(axis, 1) - new_shape = tuple(new_shape) - - result_tensor = ctx.build_op("reshape", [x], new_shape, x.dtype) - return NKIPyTensorRef(result_tensor) - - -# ----------------------------------------------------------------------------- -# concatenate -# ----------------------------------------------------------------------------- concatenate = Op("concatenate") - - -@concatenate.impl("hlo") -def _concatenate_hlo(tensors, axis=0): - """Concatenate tensors along axis (HLO).""" - from nkipy.core.backend.hlo import get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - # Convert all tensors to HLOTensor - hlo_tensors = [] - for t in tensors: - if isinstance(t, NKIPyTensorRef): - hlo_tensors.append(t.backend_tensor) - elif isinstance(t, np.ndarray): - # Convert concrete np.ndarray to HLO constant - from nkipy.core.ops.creation import constant - - const_ref = constant(t) - hlo_tensors.append(const_ref.backend_tensor) - else: - hlo_tensors.append(t) - - if not hlo_tensors: - raise ValueError("Need at least one tensor to concatenate") - - if len(hlo_tensors) == 1: - result_tensor = ctx.build_op( - "copy", [hlo_tensors[0]], hlo_tensors[0].shape, hlo_tensors[0].dtype - ) - return NKIPyTensorRef(result_tensor) - - # Normalize negative axis - ndim = len(hlo_tensors[0].shape) - if axis < 0: - axis = ndim + axis - - if axis < 0 or axis >= ndim: - raise ValueError(f"axis {axis} is out of bounds for array of dimension {ndim}") - - # Calculate output shape - output_shape = list(hlo_tensors[0].shape) - output_shape[axis] = sum(t.shape[axis] for t in hlo_tensors) - output_shape = tuple(output_shape) - - # Get common dtype - dtype = hlo_tensors[0].dtype - for t in hlo_tensors[1:]: - dtype = np.result_type(dtype, t.dtype) - - result_tensor = ctx.build_op( - "concatenate", hlo_tensors, output_shape, dtype, {"dimension": axis} - ) - return NKIPyTensorRef(result_tensor) - - -# ----------------------------------------------------------------------------- -# split -# ----------------------------------------------------------------------------- split = Op("split") - - -@split.impl("hlo") -def _split_hlo(x, indices_or_sections, axis=0): - """Split tensor into multiple tensors (HLO).""" - from nkipy.core.backend.hlo import get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x = x.backend_tensor - - # Normalize negative axis - if axis < 0: - axis = len(x.shape) + axis - - if axis < 0 or axis >= len(x.shape): - raise ValueError( - f"axis {axis} is out of bounds for array of dimension {len(x.shape)}" - ) - - axis_size = x.shape[axis] - - # Determine split points - if isinstance(indices_or_sections, int): - n_sections = indices_or_sections - if n_sections <= 0: - raise ValueError("Number of sections must be larger than 0") - if axis_size % n_sections != 0: - raise ValueError("Array split does not result in an equal division") - - section_size = axis_size // n_sections - split_indices = [i * section_size for i in range(1, n_sections)] - else: - split_indices = list(indices_or_sections) - - split_points = [0] + split_indices + [axis_size] - - result_tensors = [] - for i in range(len(split_points) - 1): - start_idx = split_points[i] - end_idx = split_points[i + 1] - - start_indices = [0] * len(x.shape) - limit_indices = list(x.shape) - strides = [1] * len(x.shape) - - start_indices[axis] = start_idx - limit_indices[axis] = end_idx - - slice_shape = list(x.shape) - slice_shape[axis] = end_idx - start_idx - slice_shape = tuple(slice_shape) - - slice_tensor = ctx.build_op( - "slice", - [x], - slice_shape, - x.dtype, - { - "start_indices": start_indices, - "limit_indices": limit_indices, - "strides": strides, - }, - ) - - result_tensors.append(NKIPyTensorRef(slice_tensor)) - - return result_tensors - - -# ----------------------------------------------------------------------------- -# copy -# ----------------------------------------------------------------------------- copy = Op("copy") - - -@copy.impl("hlo") -def _copy_hlo(x, out=None, dtype=None): - """Copy tensor (HLO).""" - from nkipy.core.backend.hlo import get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x = x.backend_tensor - - result_tensor = ctx.build_op("copy", [x], x.shape, x.dtype) - return NKIPyTensorRef(result_tensor) - - -# ----------------------------------------------------------------------------- -# repeat -# ----------------------------------------------------------------------------- repeat = Op("repeat") - - -@repeat.impl("hlo") -def _repeat_hlo(x, repeats, axis=None): - """Repeat elements of tensor (HLO).""" - from nkipy.core.backend.hlo import get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x = x.backend_tensor - - # Handle axis=None case - flatten the array first - if axis is None: - flattened_shape = (int(np.prod(x.shape)),) - x = ctx.build_op("reshape", [x], flattened_shape, x.dtype) - axis = 0 - - # Normalize negative axis - if axis < 0: - axis = len(x.shape) + axis - - if not isinstance(repeats, (int, np.integer)): - raise TypeError( - f"Only compile-time-known integer repeats are supported, got {type(repeats).__name__}. " - "Dynamic tensor repeats are not supported in tracing." - ) - repeats = int(repeats) - - # Calculate output shape - new_shape = list(x.shape) - new_shape[axis] *= repeats - new_shape = tuple(new_shape) - - # Strategy: broadcast then reshape - # 1. Build broadcast shape by inserting repeats dimension after axis - broadcast_shape = list(x.shape) - broadcast_shape.insert(axis + 1, repeats) - broadcast_shape = tuple(broadcast_shape) - - # 2. Broadcast x to the expanded shape - # broadcast_dims maps each dim of x to the corresponding dim in broadcast_shape - # (skipping the newly inserted repeat dimension at axis+1) - broadcast_dims = [i if i <= axis else i + 1 for i in range(len(x.shape))] - x_broadcast = ctx.build_op( - "broadcast", - [x], - broadcast_shape, - x.dtype, - {"broadcast_dimensions": broadcast_dims}, - ) - - # 3. Reshape to merge axis and repeat dimensions - result_tensor = ctx.build_op("reshape", [x_broadcast], new_shape, x.dtype) - - return NKIPyTensorRef(result_tensor) - - -# ----------------------------------------------------------------------------- -# broadcast_to -# ----------------------------------------------------------------------------- broadcast_to = Op("broadcast_to") - - -@broadcast_to.impl("hlo") -def _broadcast_to_hlo(x, shape, out=None, dtype=None): - """Broadcast tensor to target shape (HLO).""" - from nkipy.core.backend.hlo import broadcast_to_shape_hlo, get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x = x.backend_tensor - - # Convert shape to tuple if needed - if isinstance(shape, int): - shape = (shape,) - target_shape = tuple(shape) - - # If shapes are already the same, just return a copy - if x.shape == target_shape: - result_tensor = ctx.build_op("copy", [x], x.shape, x.dtype) - return NKIPyTensorRef(result_tensor) - - result_tensor = broadcast_to_shape_hlo(ctx, x, target_shape) - return NKIPyTensorRef(result_tensor) - - -# ----------------------------------------------------------------------------- -# copyto -# ----------------------------------------------------------------------------- -copyto = Op("copyto") - -# Note: copyto for HLO is deprecated due to in-place semantics - - -# ----------------------------------------------------------------------------- -# squeeze -# ----------------------------------------------------------------------------- -squeeze = Op("squeeze") - - -@squeeze.impl("hlo") -def _squeeze_hlo(x, axis=None): - """Remove size-1 dimensions from tensor.""" - x_shape = x.shape - ndim = len(x_shape) - - if axis is None: - # Remove all size-1 dimensions - new_shape = tuple(s for s in x_shape if s != 1) - if not new_shape: - new_shape = () - else: - if isinstance(axis, int): - axis = (axis,) - # Normalize negative axes - axes = tuple(a if a >= 0 else ndim + a for a in axis) - # Validate - for a in axes: - if x_shape[a] != 1: - raise ValueError( - f"cannot select an axis to squeeze out which has size " - f"not equal to one, got shape[{a}] = {x_shape[a]}" - ) - new_shape = tuple(s for i, s in enumerate(x_shape) if i not in axes) - - if new_shape == x_shape: - return x - return reshape(x, new_shape) - - -# ----------------------------------------------------------------------------- -# astype -# ----------------------------------------------------------------------------- astype = Op("astype") - - -@astype.impl("hlo") -def _astype_hlo(x, dtype): - """Convert tensor to specified dtype (HLO).""" - from nkipy.core.backend.hlo import get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x = x.backend_tensor - - # If dtype is the same, just copy - if x.dtype == dtype: - result_tensor = ctx.build_op("copy", [x], x.shape, x.dtype) - else: - result_tensor = ctx.build_op("convert", [x], x.shape, dtype) - - return NKIPyTensorRef(result_tensor) - - -# ----------------------------------------------------------------------------- -# pad -# ----------------------------------------------------------------------------- +squeeze = Op("squeeze") pad = Op("pad") - - -@pad.impl("hlo") -def _pad_hlo(x, pad_width, mode="constant", constant_values=0, **kwargs): - """Pad tensor with various modes. - - Supports: - - mode='constant': Uses native HLO pad instruction - - mode='edge': Composes from slice + concatenate - """ - from nkipy.core.backend.hlo import as_hlo_tensor, get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x_shape = x.shape - x_dtype = x.dtype - x_bt = x.backend_tensor - else: - x_shape = x.shape - x_dtype = x.dtype - x_bt = x - - ndim = len(x_shape) - - # Normalize pad_width to list of (before, after) tuples - pad_width = np.asarray(pad_width) - if pad_width.ndim == 0: - pad_width = np.broadcast_to(pad_width, (ndim, 2)) - elif pad_width.ndim == 1: - if len(pad_width) == 2: - pad_width = np.broadcast_to(pad_width, (ndim, 2)) - else: - pad_width = np.array([[p, p] for p in pad_width]) - if len(pad_width) != ndim: - raise ValueError( - f"pad_width must have length {ndim} to match array dimensions, " - f"got {len(pad_width)}" - ) - # Broadcast short 2D pad_width to all dims (e.g. np.pad(a_2d, ((1,2),))) - if pad_width.ndim == 2 and len(pad_width) == 1: - pad_width = np.broadcast_to(pad_width, (ndim, 2)) - pad_width = [(int(pad_width[i, 0]), int(pad_width[i, 1])) for i in range(ndim)] - - if mode == "constant": - # Use native HLO pad instruction - # Build padding_config: list of (low, high, interior) per dim - padding_config = [(low, high, 0) for low, high in pad_width] - - # Calculate output shape - result_shape = tuple( - s + low + high for s, (low, high) in zip(x_shape, pad_width) - ) - - # Create padding value scalar - pad_value_tensor = as_hlo_tensor(ctx, constant_values, x_dtype) - - result_tensor = ctx.build_op( - "pad", - [x_bt, pad_value_tensor], - result_shape, - x_dtype, - {"padding_config": padding_config}, - ) - return NKIPyTensorRef(result_tensor) - - elif mode == "edge": - # Compose from slice + concatenate for each dimension - result = NKIPyTensorRef(x_bt) if not isinstance(x, NKIPyTensorRef) else x - for dim in range(ndim): - before, after = pad_width[dim] - if before == 0 and after == 0: - continue - - parts = [] - if before > 0: - # Slice the first element along this dim and repeat - edge_slice = _slice_single(result, dim, 0) - edge_expanded = expand_dims(edge_slice, axis=dim) - edge_repeated = repeat(edge_expanded, before, axis=dim) - parts.append(edge_repeated) - - parts.append(result) - - if after > 0: - # Slice the last element along this dim and repeat - last_idx = result.shape[dim] - 1 - edge_slice = _slice_single(result, dim, last_idx) - edge_expanded = expand_dims(edge_slice, axis=dim) - edge_repeated = repeat(edge_expanded, after, axis=dim) - parts.append(edge_repeated) - - result = concatenate(parts, axis=dim) - - return result - - else: - raise NotImplementedError( - f"Pad mode '{mode}' is not supported. Only 'constant' and 'edge' modes are available." - ) - - -# ----------------------------------------------------------------------------- -# diff -# ----------------------------------------------------------------------------- +swapaxes = Op("swapaxes") +stack = Op("stack") diff = Op("diff") - - -@diff.impl("hlo") -def _diff_hlo(a, n=1, axis=-1, prepend=None, append=None): - """Compute the n-th discrete difference along the given axis.""" - from nkipy.core.backend.hlo import get_hlo_context - from nkipy.core.ops.binary import subtract - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - ndim = len(a.shape) - if axis < 0: - axis += ndim - - result = a - for _ in range(n): - if isinstance(result, NKIPyTensorRef): - r_bt = result.backend_tensor - else: - r_bt = result - - axis_size = r_bt.shape[axis] - - # Slice for a[..., 1:] along axis - start1 = [0] * ndim - limit1 = list(r_bt.shape) - start1[axis] = 1 - shape1 = list(r_bt.shape) - shape1[axis] = axis_size - 1 - - t1 = ctx.build_op( - "slice", - [r_bt], - tuple(shape1), - r_bt.dtype, - {"start_indices": start1, "limit_indices": limit1, "strides": [1] * ndim}, - ) - - # Slice for a[..., :-1] along axis - start0 = [0] * ndim - limit0 = list(r_bt.shape) - limit0[axis] = axis_size - 1 - shape0 = list(r_bt.shape) - shape0[axis] = axis_size - 1 - - t0 = ctx.build_op( - "slice", - [r_bt], - tuple(shape0), - r_bt.dtype, - {"start_indices": start0, "limit_indices": limit0, "strides": [1] * ndim}, - ) - - result = subtract(NKIPyTensorRef(t1), NKIPyTensorRef(t0)) - - return result - - -# ----------------------------------------------------------------------------- -# flip -# ----------------------------------------------------------------------------- flip = Op("flip") - - -@flip.impl("hlo") -def _flip_hlo(x, axis=None): - """Reverse elements along one or more axes using reversed gather indices.""" - from nkipy.core.ops.indexing import take - - ndim = len(x.shape) - - if axis is None: - axes = list(range(ndim)) - elif isinstance(axis, int): - axes = [axis if axis >= 0 else axis + ndim] - else: - axes = [a if a >= 0 else a + ndim for a in axis] - - result = x - for ax in axes: - n = result.shape[ax] - reversed_indices = np.arange(n - 1, -1, -1, dtype=np.int32) - result = take(result, reversed_indices, axis=ax) - - return result - - -# ----------------------------------------------------------------------------- -# tile -# ----------------------------------------------------------------------------- tile = Op("tile") - - -@tile.impl("hlo") -def _tile_hlo(x, reps): - """Tile tensor by repeating it along each axis. - - Strategy: reshape to interleave rep dims, broadcast, reshape to merge. - For shape (A, B) with reps (r0, r1): - reshape to (1, A, 1, B) → broadcast to (r0, A, r1, B) → reshape to (r0*A, r1*B) - """ - if isinstance(reps, int): - reps = (reps,) - reps = tuple(reps) - - x_shape = x.shape - ndim = len(x_shape) - - # Pad reps or shape if they have different lengths (numpy behavior) - if len(reps) < ndim: - reps = (1,) * (ndim - len(reps)) + reps - elif len(reps) > ndim: - x = reshape(x, (1,) * (len(reps) - ndim) + x_shape) - x_shape = x.shape - ndim = len(x_shape) - - # If all reps are 1, just return a copy - if all(r == 1 for r in reps): - return copy(x) - - # Build interleaved shape: (r0, s0, r1, s1, ...) - interleaved = [] - for r, s in zip(reps, x_shape): - interleaved.append(1) - interleaved.append(s) - result = reshape(x, tuple(interleaved)) - - # Broadcast rep dims: (r0, s0, r1, s1, ...) - bcast_shape = list(result.shape) - for i, r in enumerate(reps): - bcast_shape[i * 2] = r - result = broadcast_to(result, tuple(bcast_shape)) - - # Merge pairs: (r0*s0, r1*s1, ...) - final_shape = tuple(r * s for r, s in zip(reps, x_shape)) - return reshape(result, final_shape) - +roll = Op("roll") # ----------------------------------------------------------------------------- -# roll +# copyto (deprecated for HLO due to in-place semantics) # ----------------------------------------------------------------------------- -roll = Op("roll") - - -@roll.impl("hlo") -def _roll_hlo(x, shift, axis=None): - """Roll tensor elements along a given axis.""" - x_shape = x.shape - ndim = len(x_shape) - - if axis is None: - # Flatten, roll, reshape back - total = int(np.prod(x_shape)) - flat = reshape(x, (total,)) - rolled = _roll_single_axis(flat, shift, 0) - return reshape(rolled, x_shape) - - if isinstance(shift, (list, tuple)): - if not isinstance(axis, (list, tuple)): - raise ValueError("If shift is a tuple, axis must also be a tuple") - result = x - for s, a in zip(shift, axis): - result = _roll_single_axis(result, s, a if a >= 0 else a + ndim) - return result - - if axis < 0: - axis += ndim - return _roll_single_axis(x, shift, axis) - - -def _roll_single_axis(x, shift, axis): - """Roll along a single axis: concatenate(x[shift:], x[:shift]).""" - from nkipy.core.backend.hlo import get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x_bt = x.backend_tensor - else: - x_bt = x - - axis_size = x_bt.shape[axis] - ndim = len(x_bt.shape) - - # Normalize shift to positive - shift = shift % axis_size - if shift == 0: - return NKIPyTensorRef(x_bt) if not isinstance(x, NKIPyTensorRef) else x - - # Split point: we want elements from (axis_size - shift) onward first - split_point = axis_size - shift - - # First slice: x[split_point:] along axis - start1 = [0] * ndim - limit1 = list(x_bt.shape) - start1[axis] = split_point - shape1 = list(x_bt.shape) - shape1[axis] = shift - - t1 = ctx.build_op( - "slice", - [x_bt], - tuple(shape1), - x_bt.dtype, - {"start_indices": start1, "limit_indices": limit1, "strides": [1] * ndim}, - ) - - # Second slice: x[:split_point] along axis - start0 = [0] * ndim - limit0 = list(x_bt.shape) - limit0[axis] = split_point - shape0 = list(x_bt.shape) - shape0[axis] = split_point - - t0 = ctx.build_op( - "slice", - [x_bt], - tuple(shape0), - x_bt.dtype, - {"start_indices": start0, "limit_indices": limit0, "strides": [1] * ndim}, - ) - - return concatenate([NKIPyTensorRef(t1), NKIPyTensorRef(t0)], axis=axis) - - -def _slice_single(x, dim, index): - """Slice a single element along a dimension, removing that dim.""" - from nkipy.core.backend.hlo import get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x_bt = x.backend_tensor - else: - x_bt = x - - ndim = len(x_bt.shape) - start_indices = [0] * ndim - limit_indices = list(x_bt.shape) - strides_list = [1] * ndim - - start_indices[dim] = index - limit_indices[dim] = index + 1 - - slice_shape = list(x_bt.shape) - slice_shape[dim] = 1 - - sliced = ctx.build_op( - "slice", - [x_bt], - tuple(slice_shape), - x_bt.dtype, - { - "start_indices": start_indices, - "limit_indices": limit_indices, - "strides": strides_list, - }, - ) - - # Remove the sliced dimension - result_shape = tuple(s for i, s in enumerate(x_bt.shape) if i != dim) - result = ctx.build_op("reshape", [sliced], result_shape, x_bt.dtype) - return NKIPyTensorRef(result) +copyto = Op("copyto") diff --git a/nkipy/src/nkipy/core/ops/unary.py b/nkipy/src/nkipy/core/ops/unary.py index a6da7a5..27d3479 100644 --- a/nkipy/src/nkipy/core/ops/unary.py +++ b/nkipy/src/nkipy/core/ops/unary.py @@ -2,92 +2,39 @@ # SPDX-License-Identifier: Apache-2.0 """Unary operations: abs, exp, log, sqrt, sin, cos, etc.""" -import numpy as np - from nkipy.core.ops._registry import Op -# ============================================================================= -# HLO Implementation -# ============================================================================= - - -def _build_unary_hlo(x, np_op, out=None, dtype=None): - """Build a unary HLO operation.""" - from nkipy.core.backend.hlo import as_hlo_tensor, get_hlo_context - from nkipy.core.tensor import NKIPyTensorRef - - ctx = get_hlo_context() - - if isinstance(x, NKIPyTensorRef): - x = x.backend_tensor - - # Special handling for arctan: use atan2(x, 1) - if np_op == np.arctan: - one_tensor = as_hlo_tensor(ctx, 1.0, x.dtype) - if x.shape: - one_tensor = ctx.build_op( - "broadcast", - [one_tensor], - x.shape, - x.dtype, - {"broadcast_dimensions": []}, - ) - result_tensor = ctx.build_op("atan2", [x, one_tensor], x.shape, x.dtype) - return NKIPyTensorRef(result_tensor) - - # Map numpy ops to HLO opcodes - op_map = { - np.abs: "abs", - np.exp: "exponential", - np.log: "log", - np.sqrt: "sqrt", - np.sin: "sine", - np.cos: "cosine", - np.tanh: "tanh", - np.negative: "negate", - np.ceil: "ceil", - np.floor: "floor", - np.sign: "sign", - np.bitwise_not: "not", - np.invert: "not", - } - - hlo_op = op_map.get( - np_op, np_op.__name__ if hasattr(np_op, "__name__") else str(np_op) - ) - result_tensor = ctx.build_op(hlo_op, [x], x.shape, x.dtype) - - return NKIPyTensorRef(result_tensor) - - # ----------------------------------------------------------------------------- -# Factory function for simple unary ops +# Math operations (primitives) # ----------------------------------------------------------------------------- -def _make_unary_op(name: str, np_op) -> Op: - """Create a unary Op with IR and HLO implementations.""" - op = Op(name) - - @op.impl("hlo") - def _impl_hlo(x, out=None, dtype=None): - return _build_unary_hlo(x, np_op, out=out, dtype=dtype) - - return op +abs = Op("abs") +exp = Op("exp") +log = Op("log") +sqrt = Op("sqrt") +negative = Op("negative") +sin = Op("sin") +cos = Op("cos") +arctan = Op("arctan") +tanh = Op("tanh") +ceil = Op("ceil") +floor = Op("floor") +sign = Op("sign") +# ----------------------------------------------------------------------------- +# Bitwise/logical operations (primitives) +# ----------------------------------------------------------------------------- +invert = Op("invert") +bitwise_not = Op("bitwise_not") # ----------------------------------------------------------------------------- -# Math operations +# Composed unary ops — built from other dispatched ops # ----------------------------------------------------------------------------- -abs = _make_unary_op("abs", np.abs) -exp = _make_unary_op("exp", np.exp) -log = _make_unary_op("log", np.log) -sqrt = _make_unary_op("sqrt", np.sqrt) -negative = _make_unary_op("negative", np.negative) reciprocal = Op("reciprocal") -@reciprocal.impl("hlo") -def _reciprocal_hlo(x, out=None, dtype=None): +@reciprocal.composed_impl +def _reciprocal(x, out=None, dtype=None): from nkipy.core.ops.binary import divide return divide(1.0, x) @@ -96,53 +43,38 @@ def _reciprocal_hlo(x, out=None, dtype=None): square = Op("square") -@square.impl("hlo") -def _square_hlo(x, out=None, dtype=None): +@square.composed_impl +def _square(x, out=None, dtype=None): from nkipy.core.ops.binary import multiply return multiply(x, x) -# ----------------------------------------------------------------------------- -# Trigonometric operations -# ----------------------------------------------------------------------------- -sin = _make_unary_op("sin", np.sin) -cos = _make_unary_op("cos", np.cos) - tan = Op("tan") -@tan.impl("hlo") -def _tan_hlo(x, out=None, dtype=None): +@tan.composed_impl +def _tan(x, out=None, dtype=None): from nkipy.core.ops.binary import divide return divide(sin(x), cos(x)) -arctan = _make_unary_op("arctan", np.arctan) -tanh = _make_unary_op("tanh", np.tanh) +logical_not = Op("logical_not") -# ----------------------------------------------------------------------------- -# Rounding operations -# ----------------------------------------------------------------------------- -ceil = _make_unary_op("ceil", np.ceil) -floor = _make_unary_op("floor", np.floor) -sign = _make_unary_op("sign", np.sign) -# rint: round to nearest even (banker's rounding) - simplified using ops -rint = Op("rint") +@logical_not.composed_impl +def _logical_not(x, out=None, dtype=None): + from nkipy.core.ops.binary import equal + + return equal(x, 0) + +rint = Op("rint") -@rint.impl("hlo") -def _rint_hlo(x, out=None, dtype=None): - """Round to nearest even integer (banker's rounding). - Logic: - - frac = x - floor(x) - - If frac > 0.5: round up - - If frac == 0.5 AND floor is odd: round up (to make even) - - Otherwise: use floor - """ +@rint.composed_impl +def _rint(x, out=None, dtype=None): from nkipy.core.ops.binary import ( add, divide, @@ -159,61 +91,36 @@ def _rint_hlo(x, out=None, dtype=None): frac = subtract(x, floor_val) ceil_val = add(floor_val, 1.0) - # Check if frac > 0.5 frac_gt_half = greater(frac, 0.5) - # Check if floor is odd: floor % 2 != 0 - # floor % 2 = floor - 2 * floor(floor / 2) floor_div_two = floor(divide(floor_val, 2.0)) floor_mod_two = subtract(floor_val, multiply(floor_div_two, 2.0)) is_odd = not_equal(floor_mod_two, 0.0) - # Check if frac == 0.5 AND floor is odd frac_eq_half = equal(frac, 0.5) should_round_up_for_even = logical_and(frac_eq_half, is_odd) - # First select: if frac > 0.5, use ceil, else use floor result = where(frac_gt_half, ceil_val, floor_val) - # Second select: if frac == 0.5 AND odd, use ceil (to make even) return where(should_round_up_for_even, ceil_val, result) trunc = Op("trunc") -@trunc.impl("hlo") -def _trunc_hlo(x, out=None, dtype=None): +@trunc.composed_impl +def _trunc(x, out=None, dtype=None): from nkipy.core.ops.binary import greater_equal from nkipy.core.ops.indexing import where return where(greater_equal(x, 0), floor(x), ceil(x)) -# ----------------------------------------------------------------------------- -# Bitwise/logical operations -# ----------------------------------------------------------------------------- -invert = _make_unary_op("invert", np.invert) -bitwise_not = _make_unary_op("bitwise_not", np.bitwise_not) - -logical_not = Op("logical_not") - - -@logical_not.impl("hlo") -def _logical_not_hlo(x, out=None, dtype=None): - from nkipy.core.ops.binary import equal - - return equal(x, 0) - - -# ----------------------------------------------------------------------------- -# clip: clamp values between a_min and a_max -# ----------------------------------------------------------------------------- clip = Op("clip") -@clip.impl("hlo") -def _clip_hlo(x, a_min=None, a_max=None, out=None): +@clip.composed_impl +def _clip(x, a_min=None, a_max=None, out=None): from nkipy.core.ops.binary import maximum, minimum result = x @@ -224,54 +131,43 @@ def _clip_hlo(x, a_min=None, a_max=None, out=None): return result -# ----------------------------------------------------------------------------- -# log1p: log(1 + x) -# ----------------------------------------------------------------------------- log1p = Op("log1p") -@log1p.impl("hlo") -def _log1p_hlo(x, out=None, dtype=None): - """log(1+x). Note: uses log(1+x) decomposition; loses precision for |x| << 1.""" +@log1p.composed_impl +def _log1p(x, out=None, dtype=None): from nkipy.core.ops.binary import add return log(add(x, 1.0)) -# ----------------------------------------------------------------------------- -# log2: log(x) / log(2) -# ----------------------------------------------------------------------------- log2 = Op("log2") -@log2.impl("hlo") -def _log2_hlo(x, out=None, dtype=None): +@log2.composed_impl +def _log2(x, out=None, dtype=None): + import numpy as np + from nkipy.core.ops.binary import divide return divide(log(x), float(np.log(2.0))) -# ----------------------------------------------------------------------------- -# expm1: exp(x) - 1 -# ----------------------------------------------------------------------------- expm1 = Op("expm1") -@expm1.impl("hlo") -def _expm1_hlo(x, out=None, dtype=None): +@expm1.composed_impl +def _expm1(x, out=None, dtype=None): from nkipy.core.ops.binary import subtract return subtract(exp(x), 1.0) -# ----------------------------------------------------------------------------- -# round: round to given number of decimals -# ----------------------------------------------------------------------------- round_ = Op("round") -@round_.impl("hlo") -def _round_hlo(x, decimals=0, out=None): +@round_.composed_impl +def _round(x, decimals=0, out=None): if decimals == 0: return rint(x) from nkipy.core.ops.binary import divide, multiply @@ -280,27 +176,21 @@ def _round_hlo(x, decimals=0, out=None): return divide(rint(multiply(x, scale)), scale) -# ----------------------------------------------------------------------------- -# isnan: x != x (NaN is the only value not equal to itself) -# ----------------------------------------------------------------------------- isnan = Op("isnan") -@isnan.impl("hlo") -def _isnan_hlo(x, out=None, dtype=None): +@isnan.composed_impl +def _isnan(x, out=None, dtype=None): from nkipy.core.ops.binary import not_equal return not_equal(x, x) -# ----------------------------------------------------------------------------- -# isfinite: (x - x) == 0 (both NaN and Inf fail this) -# ----------------------------------------------------------------------------- isfinite = Op("isfinite") -@isfinite.impl("hlo") -def _isfinite_hlo(x, out=None, dtype=None): +@isfinite.composed_impl +def _isfinite(x, out=None, dtype=None): from nkipy.core.ops.binary import equal, subtract return equal(subtract(x, x), 0.0) diff --git a/nkipy/src/nkipy/core/trace.py b/nkipy/src/nkipy/core/trace.py index 88b908c..2d44a4b 100644 --- a/nkipy/src/nkipy/core/trace.py +++ b/nkipy/src/nkipy/core/trace.py @@ -8,9 +8,8 @@ import numpy as np from nkipy.core._numpy_dispatch import register_all_numpy_apis -from nkipy.core.backend import tracing +from nkipy.core.backend import AliasInfo, tracing from nkipy.core.backend.hlo import ( - AliasInfo, HLOModule, HLOTraceContext, get_hlo_context, @@ -47,10 +46,48 @@ def _sanitize_array_dtype(arr: np.ndarray, name: str = "") -> np.ndarray: return arr.astype(target) +def _convert_args(sig, boundargs, convert_arg): + """Convert bound arguments to traced tensor refs. + + Shared by both HLO and nkigen specialization paths. + + Each argument is passed through *convert_arg* which replaces ndarrays + with backend-specific tensor refs and returns non-tensor values unchanged. + VAR_POSITIONAL and VAR_KEYWORD arguments are expanded so each element is + converted individually. + + Args: + sig: The function's inspect.Signature. + boundargs: The BoundArguments (already defaulted). + convert_arg: ``(name, arg) -> converted_value``. + + Returns: + ``(converted_args, converted_kwargs)`` ready to call the kernel. + """ + converted_args = [] + converted_kwargs = {} + + for name, arg in boundargs.arguments.items(): + param = sig.parameters[name] + + if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD): + converted_args.append(convert_arg(name, arg)) + elif param.kind == param.KEYWORD_ONLY: + converted_kwargs[name] = convert_arg(name, arg) + elif param.kind == param.VAR_POSITIONAL: + for item in arg: + converted_args.append(convert_arg(name, item)) + elif param.kind == param.VAR_KEYWORD: + for k, v in arg.items(): + converted_kwargs[k] = convert_arg(k, v) + + return converted_args, converted_kwargs + + class NKIPyKernel: """Simplified kernel wrapper for NKIPy tracing""" - def __init__(self, func, backend, **kwargs): + def __init__(self, func, backend): self.func = func self.backend = backend self._code = None @@ -65,102 +102,81 @@ def __repr__(self): def specialize(self, *args, **kwargs): if self.backend == "hlo": return self._specialize_hlo(*args, **kwargs) + elif self.backend == "nkigen": + return self._specialize_nkigen(*args, **kwargs) elif self.backend == "cpu": - print("CPU backend does not require specialization") + warnings.warn( + "CPU backend does not require specialization", stacklevel=2 + ) return else: raise ValueError(f"Unknown backend {self.backend}") def _create_parameter_hlo(self, shape, dtype, name=""): - """Create an HLO parameter tensor""" + """Create an HLO parameter tensor.""" ctx = get_hlo_context() hlo_tensor = ctx.module.add_parameter(shape, dtype, name=name) return NKIPyTensorRef(hlo_tensor, name=name) def _specialize_hlo(self, *args, **kwargs): - """Trace the kernel with specific arguments""" + """Trace the kernel with specific arguments.""" + from nkipy.core.ops._register_hlo import register_all_hlo_impls + + register_all_hlo_impls() code = HLOModule(name=self.func.__name__) with tracing(HLOTraceContext(code)): - # Bind arguments sig = inspect.signature(self.func) boundargs = sig.bind(*args, **kwargs) boundargs.apply_defaults() - # Convert numpy arrays to tensor references - converted_args = [] - converted_kwargs = {} - # Track parameter tensor refs: list of (param_name, tensor_ref) for arrays param_tensor_refs = [] - for name, arg in boundargs.arguments.items(): - param = sig.parameters[name] - + def _make_hlo_ref(name, arg): if isinstance(arg, np.ndarray): arg = _sanitize_array_dtype(arg, name) tensor_ref = self._create_parameter_hlo(arg.shape, arg.dtype, name) tensor_ref._original_parameter = tensor_ref.backend_tensor - converted_value = tensor_ref - param_tensor_refs.append((name, tensor_ref)) - else: - converted_value = arg - - # Determine if this should be positional or keyword - if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD): - converted_args.append(converted_value) - elif param.kind == param.KEYWORD_ONLY: - converted_kwargs[name] = converted_value - elif param.kind == param.VAR_POSITIONAL: - if isinstance(arg, (list, tuple)): - for item in arg: - if isinstance(item, np.ndarray): - item = _sanitize_array_dtype(item, f"{name}_item") - converted_args.append( - self._create_parameter_hlo( - item.shape, item.dtype, f"{name}_item" - ) - ) - else: - converted_args.append(item) - else: - converted_args.append(converted_value) - elif param.kind == param.VAR_KEYWORD: - if isinstance(arg, dict): - for k, v in arg.items(): - if isinstance(v, np.ndarray): - v = _sanitize_array_dtype(v, k) - converted_kwargs[k] = self._create_parameter_hlo( - v.shape, v.dtype, k - ) - else: - converted_kwargs[k] = v - - # Execute function + param_index = tensor_ref.backend_tensor.parameter_id + param_tensor_refs.append((name, param_index, tensor_ref)) + return tensor_ref + return arg + + converted_args, converted_kwargs = _convert_args( + sig, boundargs, _make_hlo_ref + ) + ret = self.func(*converted_args, **converted_kwargs) - # Mark outputs self._mark_hlo_outputs(code, ret, param_tensor_refs) self._code = code return code - def _mark_hlo_outputs(self, code: HLOModule, ret, param_tensor_refs): - """Mark HLO outputs using mutation tracking. + @staticmethod + def _detect_mutations(ret, param_tensor_refs): + """Detect mutated parameters and auto-append them to the return list. - Detects aliasing by checking which parameter tensor refs were mutated - (had __setitem__ called on them) during kernel execution. + Checks which parameter tensor refs were mutated (had __setitem__ + called on them) during kernel execution. Mutated parameters that + the user did not return are appended automatically so the backend + can compile them as outputs. Note: Only direct mutations on the original parameter tensor refs are detected. View aliasing (e.g. ``b = a[0]; b[x] = y``) is not tracked - because ``__getitem__`` creates a new tensor ref with no parent link. + because ``__getitem__`` creates a new NKIPyTensorRef with no parent link. Args: - code: The HLOModule being built - ret: The return value(s) from the kernel function - param_tensor_refs: List of (param_name, tensor_ref) for array parameters + ret: The return value(s) from the kernel function. + param_tensor_refs: List of (param_name, param_index, tensor_ref). + + Returns: + ``(ret, user_return_len, alias_map)`` where *ret* is the + (possibly extended) list of outputs, *user_return_len* is the + original count before auto-appending, and *alias_map* is + ``{output_index: (param_name, param_index)}``. """ - # Normalize return value to a list (may be None for mutation-only kernels) if ret is None: ret = [] elif not isinstance(ret, (list, tuple)): @@ -168,62 +184,64 @@ def _mark_hlo_outputs(self, code: HLOModule, ret, param_tensor_refs): ret = list(ret) user_return_len = len(ret) - ctx = get_hlo_context() - - # Step 1: For each mutated param, rename HLO parameter. - # Check if user returned it; if not, auto-append to output list. - aliased_return_positions = {} # output_index -> (param_name, param_index) - for name, tr in param_tensor_refs: + alias_map = {} + for name, pidx, tr in param_tensor_refs: if not tr._is_mutated: continue - - # Rename HLO parameter for compiler convention - param_index = None - for hlo_param in code.parameters: - if hlo_param.name == name: - hlo_param.name = f"{name}.must_alias_input" - param_index = hlo_param.parameter_id - break - - if param_index is None: - raise RuntimeError( - f"Mutated parameter '{name}' not found in HLO parameters" - ) - - # Check if this mutated param is in the user's return values (identity check) found_at = None for i, r in enumerate(ret): if isinstance(r, NKIPyTensorRef) and r is tr: found_at = i break - if found_at is not None: - aliased_return_positions[found_at] = (name, param_index) + alias_map[found_at] = (name, pidx) else: - # Auto-append to output list ret.append(tr) - aliased_return_positions[len(ret) - 1] = (name, param_index) + alias_map[len(ret) - 1] = (name, pidx) + + return ret, user_return_len, alias_map - # Step 2: Insert explicit copy for unmutated pass-through outputs. + def _mark_hlo_outputs(self, code: HLOModule, ret, param_tensor_refs): + """Mark HLO outputs using mutation tracking. + + Args: + code: The HLOModule being built + ret: The return value(s) from the kernel function + param_tensor_refs: List of (param_name, param_index, tensor_ref) + """ + ret, user_return_len, alias_map = self._detect_mutations( + ret, param_tensor_refs + ) + + ctx = get_hlo_context() + + # Rename mutated HLO parameters for compiler convention + for _, (param_name, _) in alias_map.items(): + for hlo_param in code.parameters: + if hlo_param.name == param_name: + hlo_param.name = f"{param_name}.must_alias_input" + break + + # Insert explicit copy for unmutated pass-through outputs. # The Neuron compiler cannot handle outputs that are raw parameter # references because inputs and outputs occupy separate memory regions. for i, r in enumerate(ret): if not isinstance(r, NKIPyTensorRef): continue - if i in aliased_return_positions: + if i in alias_map: continue bt = r.backend_tensor if bt.is_parameter: copy_tensor = ctx.build_op("copy", [bt], bt.shape, bt.dtype) ret[i] = NKIPyTensorRef(copy_tensor, name="") - # Step 3: Assign output names and build AliasInfo list + # Assign output names and build AliasInfo list for idx, r in enumerate(ret): if not isinstance(r, NKIPyTensorRef): raise RuntimeError(f"Unexpected return value type: {type(r)}") - if idx in aliased_return_positions: - param_name, param_index = aliased_return_positions[idx] + if idx in alias_map: + param_name, param_index = alias_map[idx] code.aliases.append( AliasInfo( output_index=idx, @@ -242,9 +260,124 @@ def _mark_hlo_outputs(self, code: HLOModule, ret, param_tensor_refs): result_tensors = [r.backend_tensor for r in ret] code.set_results(result_tensors) + def _specialize_nkigen(self, *args, **kwargs): + """Trace the kernel to MLIR linalg/tensor IR via the nkigen backend.""" + from nkipy.core.backend.nkigen import NkiGenTraceContext + from nkipy.core.ops._register_nkigen import register_all_nkigen_impls + + register_all_nkigen_impls() + + kctx = NkiGenTraceContext() + + sig = inspect.signature(self.func) + boundargs = sig.bind(*args, **kwargs) + boundargs.apply_defaults() + + arg_shapes = [] + arg_dtypes = [] + arg_names = [] + + def _collect_array(name, arg): + arg = _sanitize_array_dtype(arg, name) + arg_shapes.append(arg.shape) + arg_dtypes.append(arg.dtype) + arg_names.append(name) + return arg + + for name, arg in boundargs.arguments.items(): + param = sig.parameters[name] + if param.kind == param.VAR_POSITIONAL: + sanitized = [] + for item in arg: + sanitized.append( + _collect_array(name, item) + if isinstance(item, np.ndarray) + else item + ) + boundargs.arguments[name] = tuple(sanitized) + elif param.kind == param.VAR_KEYWORD: + for k, v in arg.items(): + if isinstance(v, np.ndarray): + arg[k] = _collect_array(k, v) + elif isinstance(arg, np.ndarray): + arg = _collect_array(name, arg) + boundargs.arguments[name] = arg + + param_tensors = kctx._begin_function(self.func.__name__, arg_shapes, arg_dtypes) + for pt, name in zip(param_tensors, arg_names): + pt.name = name + + param_tensor_refs = [] + + with tracing(kctx): + param_idx = 0 + + def _make_kg_ref(name, arg): + nonlocal param_idx + if isinstance(arg, np.ndarray): + ref = NKIPyTensorRef(param_tensors[param_idx], name=name) + param_tensor_refs.append((name, param_idx, ref)) + param_idx += 1 + return ref + return arg + + converted_args, converted_kwargs = _convert_args( + sig, boundargs, _make_kg_ref + ) + + raw_ret = self.func(*converted_args, **converted_kwargs) + + ret, user_return_len, alias_map = self._detect_mutations( + raw_ret, param_tensor_refs + ) + + result_kg_tensors = [] + for r in ret: + if isinstance(r, NKIPyTensorRef): + result_kg_tensors.append(r.backend_tensor) + else: + raise RuntimeError(f"Unexpected return type: {type(r)}") + + kctx._finish_function(result_kg_tensors) + + kctx._run_canonicalize() + + mlir_text = kctx._get_ir_text() + kctx._cleanup() + + # BIR emission assigns "in_tensor_N" for inputs and "output" / + # "output_N" for outputs regardless of caller-provided names, so IR + # input specs must match what the NEFF will contain. + num_outputs = len(result_kg_tensors) + input_info = [ + (f"in_tensor_{i}", shape, dtype) + for i, (shape, dtype) in enumerate(zip(arg_shapes, arg_dtypes)) + ] + output_info = [ + ( + "output" if num_outputs == 1 else f"output_{i}", + t.shape, + t.dtype, + ) + for i, t in enumerate(result_kg_tensors) + ] + + from nkipy.core.backend.nkigen import NkiGenIR + + self._code = NkiGenIR( + mlir_text=mlir_text, + func_name=self.func.__name__, + input_specs=input_info, + output_specs=output_info, + alias_map=alias_map, + user_return_len=user_return_len, + original_param_names=arg_names, + ) + return self._code + @classmethod - def trace(cls, func=None, backend="hlo", **kwargs): - """Decorator to create traced kernel""" + def trace(cls, func=None, backend="hlo"): + """Decorator to create traced kernel.""" if func is None: - return lambda f: cls(f, backend, **kwargs) - return cls(func, backend, **kwargs) + return lambda f: cls(f, backend) + return cls(func, backend) diff --git a/nkipy/src/nkipy/runtime/baremetal_executor.py b/nkipy/src/nkipy/runtime/baremetal_executor.py index 3a01937..15e9d08 100644 --- a/nkipy/src/nkipy/runtime/baremetal_executor.py +++ b/nkipy/src/nkipy/runtime/baremetal_executor.py @@ -9,6 +9,7 @@ import numpy as np +from nkipy.core.backend import prepare_io_mapping from nkipy.runtime.device_kernel import DeviceKernel from nkipy.runtime.device_tensor import DeviceTensor @@ -71,26 +72,20 @@ def _prepare_io_tensors( } # Prepare inputs using DeviceTensor - inputs = {} - for intensor in compiled_kernel.ir.inputs: - real_name = ( - intensor.name.split(".must_alias_input")[0] - if ".must_alias_input" in intensor.name - else intensor.name - ) - np_tensor = original_inputs.get(real_name, boundargs.arguments[real_name]) - inputs[intensor.name] = DeviceTensor.from_numpy(np_tensor) + ir = compiled_kernel.ir + input_arrays, alias_input_names = prepare_io_mapping(ir.inputs, ir.aliases, original_inputs) + inputs = { + name: DeviceTensor.from_numpy(arr) + for name, arr in input_arrays.items() + } # Prepare outputs — aliased outputs share the input device buffer outputs = device_kernel.allocate_output_tensors() outputs_dict = {t.name: t for t in outputs} - alias_by_output = {a.output_index: a for a in compiled_kernel.ir.aliases} - for i, outtensor in enumerate(compiled_kernel.ir.outputs): - if i in alias_by_output: - alias = alias_by_output[i] - input_name = f"{alias.param_name}.must_alias_input" - outputs_dict[outtensor.name] = inputs[input_name] + for i, outtensor in enumerate(ir.outputs): + if i in alias_input_names: + outputs_dict[outtensor.name] = inputs[alias_input_names[i]] return inputs, outputs_dict, original_inputs diff --git a/nkipy/src/nkipy/runtime/decorators.py b/nkipy/src/nkipy/runtime/decorators.py index 65fd000..5d3193f 100644 --- a/nkipy/src/nkipy/runtime/decorators.py +++ b/nkipy/src/nkipy/runtime/decorators.py @@ -10,6 +10,7 @@ def baremetal_jit( kernel_func=None, *, + backend="hlo", additional_compiler_args="", target=compile.CompilationTarget.DEFAULT, ): @@ -21,6 +22,7 @@ def baremetal_jit( Args: kernel_func: The kernel function to decorate (when used without parentheses) + backend: Compilation backend ("hlo" or "nkigen") additional_compiler_args: Additional arguments to pass to the compiler target: Compilation target (default: CompilationTarget.DEFAULT) @@ -35,8 +37,8 @@ def my_kernel(A, B): # Compiles on first call with this signature result = my_kernel(input_a, input_b) - # Or with compiler args: - @baremetal_jit(additional_compiler_args="--lnc 1") + # Or with nkigen backend: + @baremetal_jit(backend="nkigen") def my_kernel(A, B): return A @ B """ @@ -45,7 +47,7 @@ def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): # Trace the kernel - traced_kernel = trace(func) + traced_kernel = trace(func, backend=backend) # Use baremetal_run_traced_kernel for execution return baremetal_run_traced_kernel( traced_kernel, diff --git a/nkipy/src/nkipy/runtime/device_kernel.py b/nkipy/src/nkipy/runtime/device_kernel.py index 2960fe9..2737fd2 100644 --- a/nkipy/src/nkipy/runtime/device_kernel.py +++ b/nkipy/src/nkipy/runtime/device_kernel.py @@ -1,14 +1,12 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import atexit -import hashlib import os import shutil import time import types from nkipy.core import compile -from nkipy.core.backend.hlo import HLOModule from nkipy.core.compile import CompilationTarget, _get_build_dir, compile_to_neff, trace from nkipy.core.logger import get_logger from nkipy.core.trace import NKIPyKernel @@ -35,24 +33,6 @@ def _cleanup_kernels(): atexit.register(_cleanup_kernels) -def _hlo_content_hash(hlo_module: HLOModule, compiler_args: str) -> str: - """Compute a content hash from the HLO protobuf and compiler args. - - Hashing the HLO (instead of source code) ensures that different input - shapes/dtypes produce different cache entries, even when the kernel - source is identical. - - The HLO proto uses only ``repeated`` fields (no ``map`` fields), so - ``SerializeToString()`` is deterministic for the same computation graph. - """ - h = hashlib.sha256() - - # TODO: this SerializeToString can be slow for large HLO - h.update(hlo_module.to_proto().SerializeToString()) - h.update(compiler_args.encode("utf-8")) - return h.hexdigest()[:12] - - def _is_distributed() -> bool: """Check if running in a multi-worker torch.distributed setting.""" return ( @@ -255,12 +235,7 @@ def _trace_and_compile( traced_kernel.specialize(*numpy_args, **numpy_kwargs) - # Compute content hash from HLO - hlo_module = traced_kernel._code - if not isinstance(hlo_module, HLOModule): - raise NotImplementedError("Only HLOModule is supported for content hashing") - - content_hash = _hlo_content_hash(hlo_module, compiler_args) + content_hash = traced_kernel._code.content_hash(compiler_args) cache_key = f"{name}_{content_hash}" # Determine output paths diff --git a/nkipy/src/nkipy/runtime/execute.py b/nkipy/src/nkipy/runtime/execute.py index da20063..a4ca917 100644 --- a/nkipy/src/nkipy/runtime/execute.py +++ b/nkipy/src/nkipy/runtime/execute.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 """Execution wrappers for NKIPy kernels""" +from __future__ import annotations + import inspect import os import shutil @@ -9,7 +11,7 @@ import numpy as np from nkipy.core import compile -from nkipy.core.trace import _sanitize_array_dtype +from nkipy.core.backend import ComputationIR, prepare_io_mapping try: from nkipy.runtime.device_kernel import DeviceKernel @@ -30,39 +32,21 @@ def _compile_kernel( ): """Specialize and compile a traced kernel to NEFF. - Returns (neff_path, kernel_name, ir, boundargs). + Returns (neff_path, kernel_name, ir, original_inputs). """ - # Sanitize unsupported dtypes (float64/int64/uint64) before tracing - args = tuple( - _sanitize_array_dtype(a, f"arg{i}") if isinstance(a, np.ndarray) else a - for i, a in enumerate(args) - ) - kwargs = { - k: _sanitize_array_dtype(v, k) if isinstance(v, np.ndarray) else v - for k, v in kwargs.items() - } - - # Trace the kernel with the provided arguments kernel.specialize(*args, **kwargs) ir = kernel._code - # Bind arguments for input/output mapping sig = inspect.signature(kernel.func) boundargs = sig.bind(*args, **kwargs) boundargs.apply_defaults() - # Save original input arrays before output allocation may overwrite them original_inputs = { name: arr for name, arr in boundargs.arguments.items() if isinstance(arr, np.ndarray) } - # Allocate output tensors based on IR outputs - for outtensor in ir.outputs: - output_array = np.empty(outtensor.shape, dtype=outtensor.dtype) - boundargs.arguments[outtensor.name] = output_array - name = kernel.__name__ build_dir = artifacts_dir if artifacts_dir else f"{compile._get_build_dir()}/{name}" @@ -88,10 +72,10 @@ def _compile_kernel( target=target, ) - return neff, name, ir, boundargs, original_inputs + return neff, name, ir, original_inputs -def _execute_neff(neff, name, ir, boundargs, original_inputs, save_trace=False): +def _execute_neff(neff, name, ir: ComputationIR, original_inputs, save_trace=False): """Load a compiled NEFF and run it on hardware. Returns output numpy array(s), with auto-aliased outputs filtered out. @@ -104,46 +88,36 @@ def _execute_neff(neff, name, ir, boundargs, original_inputs, save_trace=False): device_kernel = DeviceKernel.load_from_neff(neff, name) - # Build alias lookup: output_index -> AliasInfo - alias_by_output = {a.output_index: a for a in ir.aliases} - - device_inputs = {} - for intensor in ir.inputs: - if "must_alias_input" in intensor.name: - base_name = intensor.name.split(".must_alias_input")[0] - np_tensor = original_inputs[base_name] - else: - np_tensor = boundargs.arguments[intensor.name] - device_inputs[intensor.name] = DeviceTensor.from_numpy(np_tensor) + ir_inputs = ir.inputs + input_arrays, alias_input_names = prepare_io_mapping(ir_inputs, ir.aliases, original_inputs) + device_inputs = { + input_name: DeviceTensor.from_numpy(arr) + for input_name, arr in input_arrays.items() + } device_outputs = {} for i, outtensor in enumerate(ir.outputs): - if i in alias_by_output: - # Aliased output shares the same device buffer as the input - alias = alias_by_output[i] - input_name = f"{alias.param_name}.must_alias_input" - device_outputs[outtensor.name] = device_inputs[input_name] + if i in alias_input_names: + device_outputs[outtensor.name] = device_inputs[alias_input_names[i]] else: np_output = np.zeros(outtensor.shape, dtype=outtensor.dtype) device_outputs[outtensor.name] = DeviceTensor.from_numpy(np_output) device_kernel(inputs=device_inputs, outputs=device_outputs, save_trace=save_trace) + output_arrays = {} + alias_by_output = {a.output_index: a for a in ir.aliases} for i, outtensor in enumerate(ir.outputs): result = device_outputs[outtensor.name].numpy() if i in alias_by_output: alias = alias_by_output[i] np.copyto(dst=original_inputs[alias.param_name], src=result) - # Point boundargs at the same array so the return logic can find it - boundargs.arguments[outtensor.name] = original_inputs[alias.param_name] - else: - dst = boundargs.arguments[outtensor.name] - np.copyto(dst=dst, src=result) + output_arrays[outtensor.name] = result # Filter out auto-aliased outputs (not user-returned) auto_indices = ir.auto_aliased_indices user_outputs = [ - boundargs.arguments[out.name] + output_arrays[out.name] for i, out in enumerate(ir.outputs) if i not in auto_indices ] @@ -165,7 +139,7 @@ def baremetal_run_traced_kernel( **kwargs, ): """Compile and run a traced kernel on hardware.""" - neff, name, ir, boundargs, original_inputs = _compile_kernel( + neff, name, ir, original_inputs = _compile_kernel( kernel, *args, artifacts_dir=artifacts_dir, @@ -174,5 +148,5 @@ def baremetal_run_traced_kernel( **kwargs, ) return _execute_neff( - neff, name, ir, boundargs, original_inputs, save_trace=save_trace + neff, name, ir, original_inputs, save_trace=save_trace ) diff --git a/nkipy/src/nkipy/tools/kernel_agent/executor.py b/nkipy/src/nkipy/tools/kernel_agent/executor.py index 74f4f62..36a1f6b 100644 --- a/nkipy/src/nkipy/tools/kernel_agent/executor.py +++ b/nkipy/src/nkipy/tools/kernel_agent/executor.py @@ -80,7 +80,7 @@ def run_kernel( from nkipy.runtime.execute import _compile_kernel traced = NKIPyKernel.trace(kernel_fn) - neff, kname, ir, boundargs, original_inputs = _compile_kernel( + neff, kname, ir, original_inputs = _compile_kernel( traced, *args, artifacts_dir=artifacts_dir ) result.compile = StageResult(success=True) @@ -93,7 +93,7 @@ def run_kernel( try: from nkipy.runtime.execute import _execute_neff - out = _execute_neff(neff, kname, ir, boundargs, original_inputs) + out = _execute_neff(neff, kname, ir, original_inputs) result.hardware = StageResult(success=True, output=np.asarray(out)) except Exception as e: result.hardware = StageResult(success=False, error=str(e)) diff --git a/tests/test_nkigen_numerical.py b/tests/test_nkigen_numerical.py new file mode 100644 index 0000000..fa08592 --- /dev/null +++ b/tests/test_nkigen_numerical.py @@ -0,0 +1,130 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Numerical correctness and full-pipeline tests for the nkigen backend. + +Two levels of verification: +1. LLVM JIT smoke test — trace via nkipy, run through MLIR passes, execute + via LLVM JIT to verify numerical correctness without hardware. +2. NEFF compilation — trace via nkipy with knob() annotations, compile all + the way to NEFF to catch pass-pipeline and mem_space enum issues. + +Requires nkigen (pass pipeline, LLVM JIT infrastructure). +""" + +import numpy as np +import pytest + +try: + from nkigen.llvm import LLVMModule + + HAS_NKIGEN = True +except ImportError: + HAS_NKIGEN = False + +from nkipy.core.trace import NKIPyKernel +from nkipy.core.knob import knob + +pytestmark = pytest.mark.skipif( + not HAS_NKIGEN, reason="nkigen not installed" +) + + +def _trace_and_run_llvm(func, *np_args): + """Trace via nkipy nkigen, execute via LLVM JIT, return result.""" + kernel = NKIPyKernel.trace(func, backend="nkigen") + ir = kernel.specialize(*np_args) + mod = LLVMModule(ir._mlir_text, ir._func_name) + return mod(*np_args) + + +def _trace_and_compile_to_neff(func, *np_args): + """Trace a nkigen kernel and compile all the way to NEFF. + + Exercises the full nkipy.core.knob -> builder.annotate() -> MLIR pass + pipeline -> NISA -> neuronx-cc -> NEFF path. Raises on any failure. + """ + import shutil + import tempfile + + from nkipy.core import compile as nkipy_compile + + kernel = NKIPyKernel.trace(func, backend="nkigen") + kernel.specialize(*np_args) + + artifacts_dir = tempfile.mkdtemp(prefix="nkigen_neff_test_") + try: + nkipy_compile.compile_to_neff( + kernel, + artifacts_dir, + additional_compiler_args=nkipy_compile.nkipy_compiler_args, + ) + finally: + shutil.rmtree(artifacts_dir, ignore_errors=True) + + +class TestNumericalLLVMJIT: + """Smoke tests: trace through nkipy, verify numerics via LLVM JIT.""" + + def test_matmul_add(self): + def kernel(a, b, bias): + return np.matmul(a, b) + bias + + a = np.random.randn(4, 8).astype(np.float32) + b = np.random.randn(8, 4).astype(np.float32) + bias = np.random.randn(4).astype(np.float32) + result = _trace_and_run_llvm(kernel, a, b, bias) + np.testing.assert_allclose(result, a @ b + bias, rtol=1e-4, atol=1e-4) + + def test_sigmoid(self): + def kernel(x): + return np.reciprocal(1.0 + np.exp(-x)) + + x = np.random.randn(4, 8).astype(np.float32) + result = _trace_and_run_llvm(kernel, x) + expected = 1.0 / (1.0 + np.exp(-x)) + np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + + +class TestNumericalFullPipeline: + """Compile to NEFF end-to-end via nkipy trace+compile. + + Exercises the nkipy.core.knob -> builder.annotate() -> MLIR -> NISA -> + NEFF path, catching issues like mem_space enum mismatches between the + Python builder and the MLIR dialect definition. + """ + + def test_add_full_pipeline(self): + def kernel(a, b): + C = np.add(a, b) + C = knob(C, mem_space="SharedHbm", tile_size=[128, 128]) + return C + + a = np.random.randn(128, 128).astype(np.float32) + b = np.random.randn(128, 128).astype(np.float32) + _trace_and_compile_to_neff(kernel, a, b) + + def test_matmul_full_pipeline(self): + def kernel(a, b): + C = np.matmul(a, b) + C = knob(C, mem_space="SharedHbm", tile_size=[128, 128], + reduction_tile=[128]) + return C + + a = np.random.randn(128, 256).astype(np.float32) + b = np.random.randn(256, 128).astype(np.float32) + _trace_and_compile_to_neff(kernel, a, b) + + def test_sigmoid_full_pipeline(self): + def kernel(x): + neg_x = -x + neg_x = knob(neg_x, mem_space="Sbuf", tile_size=[128, 128]) + exp_neg = np.exp(neg_x) + exp_neg = knob(exp_neg, mem_space="Sbuf", tile_size=[128, 128]) + denom = 1.0 + exp_neg + denom = knob(denom, mem_space="Sbuf", tile_size=[128, 128]) + result = 1.0 / denom + result = knob(result, mem_space="SharedHbm", tile_size=[128, 128]) + return result + + x = np.random.randn(128, 256).astype(np.float32) + _trace_and_compile_to_neff(kernel, x) diff --git a/tests/test_nkigen_ops.py b/tests/test_nkigen_ops.py new file mode 100644 index 0000000..b984aa1 --- /dev/null +++ b/tests/test_nkigen_ops.py @@ -0,0 +1,367 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the nkigen backend integration. + +Each test traces a kernel with backend="nkigen", compiles to NEFF, +runs on Neuron device, and compares the numerical result against NumPy. +When no device is available, falls back to compile-only validation. +""" + +import numpy as np +import pytest + +try: + import nkigen # noqa: F401 + + HAS_NKIGEN = True +except ImportError: + HAS_NKIGEN = False + +from utils import ( + NEURON_AVAILABLE, + baremetal_assert_allclose, + on_device_test, + trace_and_compile, +) + +pytestmark = pytest.mark.skipif( + not HAS_NKIGEN, reason="nkigen not installed" +) + +TRACE_MODE = "nkigen" + + +def _run_kernel(kernel_fn, *args): + """Run a kernel on device if available, else compile-only. Returns result or None.""" + if NEURON_AVAILABLE: + return on_device_test(kernel_fn, TRACE_MODE, *args) + else: + trace_and_compile(kernel_fn, TRACE_MODE, *args) + return None + + +class TestNkiGenBasicOps: + """Test basic arithmetic operations compile and run correctly.""" + + def test_add(self): + def kernel(A, B): + return np.add(A, B) + + A = np.random.randn(128, 128).astype(np.float32) + B = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A, B) + if result is not None: + baremetal_assert_allclose(result, A + B) + + def test_subtract(self): + def kernel(A, B): + return np.subtract(A, B) + + A = np.random.randn(128, 128).astype(np.float32) + B = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A, B) + if result is not None: + baremetal_assert_allclose(result, A - B) + + def test_multiply(self): + def kernel(A, B): + return np.multiply(A, B) + + A = np.random.randn(128, 128).astype(np.float32) + B = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A, B) + if result is not None: + baremetal_assert_allclose(result, A * B) + + def test_scalar_add(self): + def kernel(A): + return A + 1.0 + + A = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, A + 1.0) + + def test_matmul(self): + def kernel(A, B): + return np.matmul(A, B) + + A = np.random.randn(128, 128).astype(np.float32) + B = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A, B) + if result is not None: + baremetal_assert_allclose(result, A @ B) + + def test_matmul_batched(self): + def kernel(A, B): + return np.matmul(A, B) + + A = np.random.randn(2, 128, 128).astype(np.float32) + B = np.random.randn(2, 128, 128).astype(np.float32) + result = _run_kernel(kernel, A, B) + if result is not None: + baremetal_assert_allclose(result, A @ B) + + +class TestNkiGenUnaryOps: + """Test unary operations compile and run correctly.""" + + def test_exp(self): + def kernel(A): + return np.exp(A) + + A = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, np.exp(A)) + + def test_sqrt(self): + def kernel(A): + return np.sqrt(A) + + A = np.abs(np.random.randn(128, 128)).astype(np.float32) + 0.01 + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, np.sqrt(A)) + + def test_tanh(self): + def kernel(A): + return np.tanh(A) + + A = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, np.tanh(A)) + + def test_negative(self): + def kernel(A): + return -A + + A = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, -A) + + +class TestNkiGenTransformOps: + """Test transform operations compile and run correctly.""" + + @pytest.mark.xfail(reason="linalg.transpose not lowered to NISA yet", run=True, strict=False) + def test_transpose(self): + def kernel(A): + return np.transpose(A) + + A = np.random.randn(128, 256).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, A.T) + + def test_reshape(self): + def kernel(A): + return np.reshape(A, (256, 64)) + + A = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, A.reshape(256, 64)) + + def test_squeeze(self): + def kernel(A): + return np.squeeze(A, axis=1) + + A = np.random.randn(128, 1, 128).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, A.squeeze(axis=1)) + + @pytest.mark.xfail(reason="linalg.transpose not lowered to NISA yet", run=True, strict=False) + def test_swapaxes(self): + def kernel(A): + return np.swapaxes(A, 0, 1) + + A = np.random.randn(128, 256).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, np.swapaxes(A, 0, 1)) + + @pytest.mark.xfail(reason="tensor.insert_slice stack lowering produces incorrect NISA", run=True, strict=False) + def test_stack(self): + def kernel(A, B): + return np.stack([A, B], axis=0) + + A = np.random.randn(128, 128).astype(np.float32) + B = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A, B) + if result is not None: + baremetal_assert_allclose(result, np.stack([A, B], axis=0)) + + +class TestNkiGenReductions: + """Test reduction operations compile and run correctly.""" + + def test_sum(self): + def kernel(A): + return np.sum(A, axis=1, keepdims=True) + + A = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, np.sum(A, axis=1, keepdims=True)) + + @pytest.mark.xfail(reason="mean reduction missing memory space annotation in NISA lowering", run=True, strict=False) + def test_mean(self): + def kernel(A): + return np.mean(A, axis=0, keepdims=True) + + A = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, np.mean(A, axis=0, keepdims=True)) + + +class TestNkiGenComparisonOps: + """Test comparison and logical operations compile and run correctly.""" + + def test_equal(self): + def kernel(A, B): + return np.equal(A, B) + + A = np.random.randn(128, 128).astype(np.float32) + B = A.copy() + B[::2, :] = 0.0 + result = _run_kernel(kernel, A, B) + if result is not None: + expected = np.equal(A, B).astype(np.float32) + baremetal_assert_allclose(result, expected) + + def test_greater(self): + def kernel(A, B): + return np.greater(A, B) + + A = np.random.randn(128, 128).astype(np.float32) + B = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A, B) + if result is not None: + expected = np.greater(A, B).astype(np.float32) + baremetal_assert_allclose(result, expected) + + def test_less_scalar(self): + def kernel(A): + return np.less(A, 0.5) + + A = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + expected = np.less(A, 0.5).astype(np.float32) + baremetal_assert_allclose(result, expected) + + def test_logical_not(self): + def kernel(A): + return np.logical_not(A) + + A = np.array(np.random.rand(128, 128) > 0.5, dtype=np.float32) + result = _run_kernel(kernel, A) + if result is not None: + expected = np.logical_not(A).astype(np.float32) + baremetal_assert_allclose(result, expected) + + def test_bitwise_and(self): + def kernel(A, B): + return np.bitwise_and(A, B) + + A = np.array(np.random.rand(128, 128) > 0.5, dtype=np.float32) + B = np.array(np.random.rand(128, 128) > 0.5, dtype=np.float32) + result = _run_kernel(kernel, A, B) + if result is not None: + expected = np.bitwise_and( + A.astype(np.int32), B.astype(np.int32) + ).astype(np.float32) + baremetal_assert_allclose(result, expected) + + +class TestNkiGenWhere: + """Test np.where compiles and runs correctly.""" + + def test_where_same_type(self): + def kernel(A, B, C): + return np.where(A, B, C) + + A = np.array(np.random.rand(128, 128) > 0.5, dtype=np.float32) + B = np.random.randn(128, 128).astype(np.float32) + C = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A, B, C) + if result is not None: + baremetal_assert_allclose(result, np.where(A, B, C)) + + def test_where_with_comparison(self): + def kernel(A, B): + mask = np.greater(A, 0.0) + return np.where(mask, A, B) + + A = np.random.randn(128, 128).astype(np.float32) + B = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A, B) + if result is not None: + baremetal_assert_allclose(result, np.where(A > 0.0, A, B)) + + +class TestNkiGenComposedKernel: + """Test non-trivial kernels that compose multiple ops.""" + + @pytest.mark.xfail(reason="broadcast add with rank-1 bias not lowered to NISA yet", run=True, strict=False) + def test_matmul_add_relu(self): + def kernel(A, B, bias): + C = np.matmul(A, B) + C = C + bias + return np.maximum(C, 0.0) + + A = np.random.randn(128, 128).astype(np.float32) + B = np.random.randn(128, 128).astype(np.float32) + bias = np.random.randn(128).astype(np.float32) + result = _run_kernel(kernel, A, B, bias) + if result is not None: + baremetal_assert_allclose(result, np.maximum(A @ B + bias, 0.0)) + + @pytest.mark.xfail(reason="composed mean/sqrt/broadcast not lowered to NISA yet", run=True, strict=False) + def test_rmsnorm(self): + def kernel(x, weight): + variance = np.mean(x * x, axis=-1, keepdims=True) + x_norm = x / np.sqrt(variance + 1e-6) + return x_norm * weight + + x = np.random.randn(128, 128).astype(np.float32) + w = np.random.randn(128).astype(np.float32) + result = _run_kernel(kernel, x, w) + if result is not None: + variance = np.mean(x * x, axis=-1, keepdims=True) + expected = (x / np.sqrt(variance + 1e-6)) * w + baremetal_assert_allclose(result, expected) + + @pytest.mark.xfail(reason="clip not lowered to NISA yet", run=True, strict=False) + def test_clip(self): + def kernel(A): + return np.clip(A, 0.0, 1.0) + + A = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, np.clip(A, 0.0, 1.0)) + + +class TestNkiGenAnnotations: + """Test knob() annotations compile to NEFF and run correctly.""" + + def test_knob_mem_space(self): + from nkipy.core.knob import knob + + def kernel(A, B): + C = np.matmul(A, B) + C = knob(C, mem_space="SharedHbm", tile_size=[128, 128], + reduction_tile=[128]) + return C + + A = np.random.randn(128, 128).astype(np.float32) + B = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A, B) + if result is not None: + baremetal_assert_allclose(result, A @ B) diff --git a/tests/unit/test_core_ops_direct.py b/tests/unit/test_core_ops_direct.py index e3f4b34..ec8b669 100644 --- a/tests/unit/test_core_ops_direct.py +++ b/tests/unit/test_core_ops_direct.py @@ -342,7 +342,7 @@ def test_reduce_unsupported_op(self, trace_mode): """_build_reduction_hlo with unsupported op raises NotImplementedError.""" def kernel(x): - from nkipy.core.ops.reduce import _build_reduction_hlo + from nkipy.core.ops._hlo_impls import _build_reduction_hlo return _build_reduction_hlo(x, np.cumsum) diff --git a/tests/unit/test_device_kernel_cache.py b/tests/unit/test_device_kernel_cache.py index 07b9949..2cc4643 100644 --- a/tests/unit/test_device_kernel_cache.py +++ b/tests/unit/test_device_kernel_cache.py @@ -1,14 +1,13 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -"""Unit tests for HLO-based kernel cache hashing.""" +"""Unit tests for IR content hashing.""" import numpy as np from nkipy.core.compile import trace -from nkipy.runtime.device_kernel import _hlo_content_hash def _trace_and_specialize(kernel_fn, *args, **kwargs): - """Helper: trace a kernel, specialize with given args, return the HLOModule.""" + """Helper: trace a kernel, specialize with given args, return the IR.""" traced = trace(kernel_fn) traced.specialize(*args, **kwargs) return traced._code @@ -20,18 +19,18 @@ def test_hlo_hash_varies_with_shape(): def add_kernel(x, y): return np.add(x, y) - hlo_small = _trace_and_specialize( + ir_small = _trace_and_specialize( add_kernel, np.zeros((2, 2), dtype=np.float32), np.zeros((2, 2), dtype=np.float32), ) - hlo_large = _trace_and_specialize( + ir_large = _trace_and_specialize( add_kernel, np.zeros((4, 4), dtype=np.float32), np.zeros((4, 4), dtype=np.float32), ) - assert _hlo_content_hash(hlo_small, "") != _hlo_content_hash(hlo_large, "") + assert ir_small.content_hash("") != ir_large.content_hash("") def test_hlo_hash_deterministic(): @@ -45,10 +44,10 @@ def add_kernel(x, y): np.zeros((2, 2), dtype=np.float32), ) - hlo1 = _trace_and_specialize(add_kernel, *inputs) - hlo2 = _trace_and_specialize(add_kernel, *inputs) + ir1 = _trace_and_specialize(add_kernel, *inputs) + ir2 = _trace_and_specialize(add_kernel, *inputs) - assert _hlo_content_hash(hlo1, "--lnc 1") == _hlo_content_hash(hlo2, "--lnc 1") + assert ir1.content_hash("--lnc 1") == ir2.content_hash("--lnc 1") def test_hlo_hash_varies_with_dtype(): @@ -57,18 +56,18 @@ def test_hlo_hash_varies_with_dtype(): def add_kernel(x, y): return np.add(x, y) - hlo_f32 = _trace_and_specialize( + ir_f32 = _trace_and_specialize( add_kernel, np.zeros((2, 2), dtype=np.float32), np.zeros((2, 2), dtype=np.float32), ) - hlo_f16 = _trace_and_specialize( + ir_f16 = _trace_and_specialize( add_kernel, np.zeros((2, 2), dtype=np.float16), np.zeros((2, 2), dtype=np.float16), ) - assert _hlo_content_hash(hlo_f32, "") != _hlo_content_hash(hlo_f16, "") + assert ir_f32.content_hash("") != ir_f16.content_hash("") def test_hlo_hash_varies_with_compiler_args(): @@ -77,12 +76,12 @@ def test_hlo_hash_varies_with_compiler_args(): def add_kernel(x, y): return np.add(x, y) - hlo = _trace_and_specialize( + ir = _trace_and_specialize( add_kernel, np.zeros((2, 2), dtype=np.float32), np.zeros((2, 2), dtype=np.float32), ) - hash1 = _hlo_content_hash(hlo, "--lnc 1") - hash2 = _hlo_content_hash(hlo, "--lnc 2") + hash1 = ir.content_hash("--lnc 1") + hash2 = ir.content_hash("--lnc 2") assert hash1 != hash2 diff --git a/tests/unit/test_nkigen_backend.py b/tests/unit/test_nkigen_backend.py new file mode 100644 index 0000000..9327291 --- /dev/null +++ b/tests/unit/test_nkigen_backend.py @@ -0,0 +1,424 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the unified nkigen backend integration. + +Ported from NeuronPy/tests/unit/test_nkigen_backend.py with adaptations +for the current nkipy implementation. +""" + +import warnings + +import numpy as np +import pytest + +from nkipy import knob +from nkipy.core.nki_op import nki_custom_op +from nkipy.core.backend import get_backend, tracing +from nkipy.core.backend.nkigen import NkiGenIR, NkiGenTraceContext + + +class TestKnobDispatch: + """Test knob() backend-aware dispatch.""" + + def test_knob_cpu_passthrough(self): + """knob() is a no-op pass-through in cpu mode (no trace).""" + arr = np.ones((4, 4), dtype=np.float32) + result = knob(arr, mem_space="Sbuf") + assert result is arr + + def test_knob_cpu_no_params(self): + """knob() with no params is always a no-op.""" + arr = np.ones((4, 4), dtype=np.float32) + result = knob(arr) + assert result is arr + + def test_knob_hlo_warns(self): + """knob() issues a warning under the HLO backend.""" + from nkipy.core.backend.hlo import HLOModule, HLOTraceContext + + code = HLOModule(name="test") + arr = np.ones((4, 4), dtype=np.float32) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + with tracing(HLOTraceContext(code)): + result = knob(arr, mem_space="Sbuf") + assert len(w) == 1 + assert "only effective with backend='nkigen'" in str(w[0].message) + assert result is arr + + +def _make_silu_kernel_builder(M, N, tile_p=128, tile_f=128): + """Return a real NKI kernel_builder function that computes SiLU activation.""" + def silu_kernel(input_0, output_0): + import nki.compiler.kernel_builder as nb + import nki.language as nl + + n_row_tiles = M // tile_p + n_col_tiles = N // tile_f + for r in nl.affine_range(n_row_tiles): + for t in nl.affine_range(n_col_tiles): + x_sbuf = nb.ndarray((tile_p, tile_f), input_0.dtype, nb.sbuf) + nb.isa.dma_copy( + dst=x_sbuf, + src=input_0[ + r * tile_p : (r + 1) * tile_p, + t * tile_f : (t + 1) * tile_f, + ], + ) + + out_sbuf = nb.ndarray((tile_p, tile_f), input_0.dtype, nb.sbuf) + bias = nb.ndarray((tile_p, 1), input_0.dtype, nb.sbuf) + nb.isa.memset(dst=bias, value=0.0) + scale = nb.ndarray((tile_p, 1), input_0.dtype, nb.sbuf) + nb.isa.memset(dst=scale, value=1.0) + + nb.isa.activation( + dst=out_sbuf, + src=x_sbuf, + bias=bias, + scale=scale, + op=nb.isa.activation_function.silu, + ) + + nb.isa.dma_copy( + dst=output_0[ + r * tile_p : (r + 1) * tile_p, + t * tile_f : (t + 1) * tile_f, + ], + src=out_sbuf, + ) + return silu_kernel + + +class TestNKICustomOpDispatch: + """Test nki_custom_op() factory and dispatch.""" + + def test_requires_at_least_one(self): + """nki_custom_op raises if neither nki_kernel nor kernel_builder given.""" + with pytest.raises(ValueError, match="At least one"): + nki_custom_op() + + def test_kernel_builder_requires_specs(self): + """nki_custom_op raises if kernel_builder given without specs.""" + with pytest.raises(ValueError, match="input_specs and output_specs"): + nki_custom_op(kernel_builder=lambda: None) + + def test_cpu_raises(self): + """nki_custom_op raises on cpu backend.""" + op = nki_custom_op( + kernel_builder=lambda: None, + input_specs=[((4, 4), "f32")], + output_specs=[((4, 4), "f32")], + ) + with pytest.raises(RuntimeError, match="not supported on backend 'cpu'"): + op(np.ones((4, 4), dtype=np.float32)) + + def test_hlo_without_nki_kernel_raises(self): + """nki_custom_op with only kernel_builder raises on HLO.""" + + class _FakeHLOCtx: + backend_name = "hlo" + + op = nki_custom_op( + kernel_builder=lambda: None, + input_specs=[((128, 128), "f32")], + output_specs=[((128, 128), "f32")], + ) + with tracing(_FakeHLOCtx()): + with pytest.raises(RuntimeError, match="no nki_kernel"): + op(np.ones((128, 128), dtype=np.float32)) + + def test_nkigen_without_kernel_builder_raises(self): + """nki_custom_op with only nki_kernel raises on nkigen.""" + + class _FakeNkigenCtx: + backend_name = "nkigen" + + op = nki_custom_op(nki_kernel=lambda: None) + with tracing(_FakeNkigenCtx()): + with pytest.raises(RuntimeError, match="no kernel_builder"): + op(np.ones((128, 128), dtype=np.float32)) + + +class TestNkiGenTraceContext: + """Test NkiGenTraceContext basics.""" + + @pytest.fixture(autouse=True) + def _skip_if_no_nkigen(self): + try: + import nkigen # noqa: F401 + except ImportError: + pytest.skip("nkigen not installed") + + def test_backend_name(self): + ctx = NkiGenTraceContext() + assert ctx.backend_name == "nkigen" + ctx._cleanup() + + def test_tracing_context_activates(self): + ctx = NkiGenTraceContext() + assert get_backend() == "cpu" + with tracing(ctx): + assert get_backend() == "nkigen" + assert get_backend() == "cpu" + ctx._cleanup() + + +class TestSpecializeNkigen: + """Test NKIPyKernel._specialize_nkigen with device compilation and execution.""" + + @pytest.fixture(autouse=True) + def _skip_if_no_nkigen(self): + try: + import nkigen # noqa: F401 + except ImportError: + pytest.skip("nkigen not installed") + + @staticmethod + def _run(func, *np_args): + from utils import NEURON_AVAILABLE, on_device_test, trace_and_compile + if NEURON_AVAILABLE: + return on_device_test(func, "nkigen", *np_args) + else: + trace_and_compile(func, "nkigen", *np_args) + return None + + def test_with_knob(self): + from utils import baremetal_assert_allclose + + def kernel_with_knob(a, b): + result = a + b + knob(result, mem_space="SharedHbm", tile_size=[128, 128]) + return result + + a = np.random.randn(128, 128).astype("float32") + b = np.random.randn(128, 128).astype("float32") + + result = self._run(kernel_with_knob, a, b) + if result is not None: + baremetal_assert_allclose(result, a + b) + + def test_multi_output(self): + from utils import baremetal_assert_allclose + + def multi_out(a, b): + return a + b, a - b + + a = np.random.randn(128, 128).astype("float32") + b = np.random.randn(128, 128).astype("float32") + + result = self._run(multi_out, a, b) + if result is not None: + baremetal_assert_allclose(result[0], a + b) + baremetal_assert_allclose(result[1], a - b) + + def test_dtype_downcast(self): + """float64 inputs should be auto-downcast to float32.""" + from nkipy.core.trace import NKIPyKernel + + def add_kernel(a, b): + return a + b + + kernel = NKIPyKernel.trace(add_kernel, backend="nkigen") + a = np.random.randn(64, 64) # float64 + b = np.random.randn(64, 64) # float64 + + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + ir = kernel.specialize(a, b) + assert ir.inputs[0].dtype == np.dtype("float32") + + @pytest.mark.xfail( + reason="custom_op kernel_builder tracing requires TracedArray, " + "not yet wired through NKIPyTensorRef path" + ) + def test_custom_op_with_kernel_builder(self): + """nki_custom_op with real kernel_builder traces through nkigen backend.""" + from nkipy.core.trace import NKIPyKernel + + silu_op = nki_custom_op( + kernel_builder=_make_silu_kernel_builder(256, 256), + input_specs=[((256, 256), "f32")], + output_specs=[((256, 256), "f32")], + ) + + def kernel(x): + return silu_op(x) + + k = NKIPyKernel.trace(kernel, backend="nkigen") + ir = k.specialize(np.random.randn(256, 256).astype("float32")) + assert isinstance(ir, NkiGenIR) + assert "__custom_op__silu_kernel" in ir._mlir_text + assert "nkipy.custom_op_bodies" in ir._mlir_text + + +class TestNkigenInplaceUpdate: + """Test in-place update (dynamic_update_slice) support for nkigen. + + Each test traces → compiles → runs on device and compares + numerical results against NumPy. Alias metadata is verified as well. + """ + + @pytest.fixture(autouse=True) + def _skip_if_no_nkigen(self): + try: + import nkigen # noqa: F401 + except ImportError: + pytest.skip("nkigen not installed") + + @staticmethod + def _trace_and_run(func, *np_args): + """Trace a nkigen kernel, return (ir, device_result_or_None).""" + from nkipy.core.trace import NKIPyKernel + from utils import NEURON_AVAILABLE, on_device_test, trace_and_compile + + kernel = NKIPyKernel.trace(func, backend="nkigen") + ir = kernel.specialize(*np_args) + if NEURON_AVAILABLE: + result = on_device_test(func, "nkigen", *np_args) + else: + trace_and_compile(func, "nkigen", *np_args) + result = None + return ir, result + + def test_single_alias(self): + """Mutate one parameter and return it — verify numerical result.""" + from utils import baremetal_assert_allclose + + def kernel(a, b): + a[0:1, :] = b[1:2, :] + return a + + a = np.random.randn(128, 128).astype("float32") + b = np.random.randn(128, 128).astype("float32") + + expected = a.copy() + expected[0:1, :] = b[1:2, :] + + ir, result = self._trace_and_run(kernel, a, b) + if result is not None: + baremetal_assert_allclose(result, expected) + + assert isinstance(ir, NkiGenIR) + assert len(ir.aliases) == 1 + assert ir.aliases[0].param_name == "a" + assert ir.aliases[0].param_index == 0 + assert ir.aliases[0].is_user_returned is True + assert ir.auto_aliased_indices == set() + + def test_multi_slice_update(self): + """Update multiple disjoint slices of the same tensor.""" + from utils import baremetal_assert_allclose + + def kernel(a, b): + a[0:2, :] = b[0:2, :] + a[4:6, :] = b[4:6, :] + return a + + a = np.random.randn(128, 128).astype("float32") + b = np.random.randn(128, 128).astype("float32") + + expected = a.copy() + expected[0:2, :] = b[0:2, :] + expected[4:6, :] = b[4:6, :] + + ir, result = self._trace_and_run(kernel, a, b) + if result is not None: + baremetal_assert_allclose(result, expected) + assert len(ir.aliases) == 1 + + def test_multi_alias(self): + """Mutate two parameters and return both.""" + from utils import baremetal_assert_allclose + + def kernel(a, b, c): + a[0:1, :] = b[0:1, :] + c[2:3, :] = b[2:3, :] + return a, c + + a = np.random.randn(128, 128).astype("float32") + b = np.random.randn(128, 128).astype("float32") + c = np.random.randn(128, 128).astype("float32") + + expected_a = a.copy() + expected_a[0:1, :] = b[0:1, :] + expected_c = c.copy() + expected_c[2:3, :] = b[2:3, :] + + ir, result = self._trace_and_run(kernel, a, b, c) + if result is not None: + baremetal_assert_allclose(result[0], expected_a) + baremetal_assert_allclose(result[1], expected_c) + + assert len(ir.aliases) == 2 + alias_names = {al.param_name for al in ir.aliases} + assert alias_names == {"a", "c"} + assert all(al.is_user_returned for al in ir.aliases) + + def test_no_return_auto_alias(self): + """Mutate without returning — auto-append to outputs.""" + from utils import baremetal_assert_allclose + + def kernel(a, b): + a[0:1, :] = b[1:2, :] + + a = np.random.randn(128, 128).astype("float32") + b = np.random.randn(128, 128).astype("float32") + + expected = a.copy() + expected[0:1, :] = b[1:2, :] + + ir, result = self._trace_and_run(kernel, a, b) + if result is not None: + baremetal_assert_allclose(result, expected) + + assert len(ir.aliases) == 1 + assert ir.aliases[0].param_name == "a" + assert ir.aliases[0].is_user_returned is False + assert ir.auto_aliased_indices == {0} + + def test_mixed_return_alias(self): + """Mutate a parameter but return a different computed value.""" + from utils import baremetal_assert_allclose + + def kernel(a, b): + a[0:1, :] = b[1:2, :] + return a + b + + a = np.random.randn(128, 128).astype("float32") + b = np.random.randn(128, 128).astype("float32") + + expected_a = a.copy() + expected_a[0:1, :] = b[1:2, :] + expected_sum = expected_a + b + + ir, result = self._trace_and_run(kernel, a, b) + if result is not None: + baremetal_assert_allclose(result, expected_sum) + + assert len(ir.aliases) == 1 + assert ir.aliases[0].param_name == "a" + assert ir.aliases[0].is_user_returned is False + assert len(ir.outputs) == 2 + assert ir.auto_aliased_indices == {1} + + def test_update_with_computation(self): + """Assign a computed expression into a slice.""" + from utils import baremetal_assert_allclose + + def kernel(a, b): + a[0:2, :] = b[0:2, :] * 2.0 + return a + + a = np.random.randn(128, 128).astype("float32") + b = np.random.randn(128, 128).astype("float32") + + expected = a.copy() + expected[0:2, :] = b[0:2, :] * 2.0 + + ir, result = self._trace_and_run(kernel, a, b) + if result is not None: + baremetal_assert_allclose(result, expected) + + diff --git a/tests/utils.py b/tests/utils.py index 8cc8baa..9b32582 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -29,28 +29,30 @@ NEURON_AVAILABLE = is_neuron_compatible() +def _trace_mode_to_backend(trace_mode): + if trace_mode in ("hlo", "nkigen"): + return trace_mode + raise ValueError(f"Unknown trace mode: {trace_mode}") + + def trace_and_compile(kernel_fn, trace_mode, *args, **kwargs): """ - Validate kernel is traceable to HLO and compilable to NEFF. + Validate kernel is traceable and compilable to NEFF. - Traces the kernel to HLO IR and compiles it using the Neuron compiler, + Traces the kernel to IR and compiles it using the Neuron compiler, but does not execute on device. Args: kernel_fn: The kernel function to test - trace_mode: "hlo" or other supported tracing mode + trace_mode: "hlo" or "nkigen" *args: Input arrays for the kernel **kwargs: Additional arguments """ - if trace_mode == "hlo": - traced_kernel = NKIPyKernel.trace(kernel_fn, backend="hlo") - else: - raise ValueError(f"Unknown trace mode: {trace_mode}") + backend = _trace_mode_to_backend(trace_mode) + traced_kernel = NKIPyKernel.trace(kernel_fn, backend=backend) - # Trace to HLO traced_kernel.specialize(*args, **kwargs) - # Compile to NEFF worker_id = os.environ.get("PYTEST_XDIST_WORKER", "main") artifacts_dir = os.path.join(tempfile.gettempdir(), f"nkipy_artifacts_{worker_id}") if os.path.exists(artifacts_dir): @@ -70,7 +72,7 @@ def on_device_test(kernel_fn, trace_mode, *args, artifacts_dir=None, **kwargs): Args: kernel_fn: The kernel function to execute - trace_mode: "hlo" or other supported tracing mode + trace_mode: "hlo" or "nkigen" *args: Input arrays for the kernel artifacts_dir: Directory for compilation artifacts (for parallel test isolation) **kwargs: Additional arguments @@ -78,17 +80,14 @@ def on_device_test(kernel_fn, trace_mode, *args, artifacts_dir=None, **kwargs): Returns: Device execution output """ - # Auto-generate worker-specific artifacts_dir if not provided if artifacts_dir is None: worker_id = os.environ.get("PYTEST_XDIST_WORKER", "main") artifacts_dir = os.path.join( tempfile.gettempdir(), f"nkipy_artifacts_{worker_id}" ) - if trace_mode == "hlo": - traced_kernel = NKIPyKernel.trace(kernel_fn, backend="hlo") - else: - raise ValueError(f"Unknown trace mode: {trace_mode}") + backend = _trace_mode_to_backend(trace_mode) + traced_kernel = NKIPyKernel.trace(kernel_fn, backend=backend) return baremetal_run_traced_kernel( traced_kernel, *args, artifacts_dir=artifacts_dir, **kwargs