From a20d42f0118bf9d1fcb4952df8560a777e38180f Mon Sep 17 00:00:00 2001 From: zhengzuohe Date: Sat, 9 May 2026 09:30:01 +0800 Subject: [PATCH] Add: cross-class @pl.jit.inline example Demonstrates splitting inlined kernels across files and classes while retaining @pl.jit dep auto-discovery via module-level aliases. - proj_lib.py: tiled matmul helper (BF16 x BF16 -> FP32) on Projections - eltwise_lib.py: residual-add helper on Elementwise - main.py: entry function that inlines both helpers into a single kernel - config.py: shared tiling constants (BATCH, HIDDEN, chunk sizes) --- .../advanced/cross_class_proj/__init__.py | 15 +++ examples/advanced/cross_class_proj/config.py | 19 +++ .../advanced/cross_class_proj/eltwise_lib.py | 39 ++++++ examples/advanced/cross_class_proj/main.py | 115 ++++++++++++++++++ .../advanced/cross_class_proj/proj_lib.py | 57 +++++++++ 5 files changed, 245 insertions(+) create mode 100644 examples/advanced/cross_class_proj/__init__.py create mode 100644 examples/advanced/cross_class_proj/config.py create mode 100644 examples/advanced/cross_class_proj/eltwise_lib.py create mode 100644 examples/advanced/cross_class_proj/main.py create mode 100644 examples/advanced/cross_class_proj/proj_lib.py diff --git a/examples/advanced/cross_class_proj/__init__.py b/examples/advanced/cross_class_proj/__init__.py new file mode 100644 index 00000000..b405399f --- /dev/null +++ b/examples/advanced/cross_class_proj/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Cross-class / cross-file ``@pl.jit.inline`` example package. + +Run with:: + + cd pypto-lib + python examples/advanced/cross_class_proj/main.py -p a2a3sim +""" diff --git a/examples/advanced/cross_class_proj/config.py b/examples/advanced/cross_class_proj/config.py new file mode 100644 index 00000000..e24a5741 --- /dev/null +++ b/examples/advanced/cross_class_proj/config.py @@ -0,0 +1,19 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Shared problem-size and tiling constants for the cross-class example.""" + +BATCH = 16 +HIDDEN = 8192 + +# Projection tiling +N_OUT_CHUNK = 256 # N tile per parallel core-group +K_PROJ_CHUNK = 128 # K reduction tile inside each scope + +# Elementwise tiling +ADD_OUT_CHUNK = 256 # column tile per parallel core-group for residual add diff --git a/examples/advanced/cross_class_proj/eltwise_lib.py b/examples/advanced/cross_class_proj/eltwise_lib.py new file mode 100644 index 00000000..f6f3f19f --- /dev/null +++ b/examples/advanced/cross_class_proj/eltwise_lib.py @@ -0,0 +1,39 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Elementwise kernels packaged as static methods on an ``Elementwise`` class. + +Same dispatch pattern as ``proj_lib.Projections``: the ``@pl.jit.inline`` body +lives inside a class for namespacing, and a module-level alias re-exports it +so ``@pl.jit`` dep auto-discovery (which only matches bare-name calls) finds +the helper from the entry function. +""" + +import pypto.language as pl + +from config import ADD_OUT_CHUNK, BATCH, HIDDEN + + +class Elementwise: + """Elementwise helpers used after the projection step.""" + @pl.jit.inline + def residual_add( + a: pl.Tensor[[BATCH, HIDDEN], pl.FP32], + b: pl.Tensor[[BATCH, HIDDEN], pl.BF16], + out: pl.Out[pl.Tensor[[BATCH, HIDDEN], pl.FP32]], + ): + """``out = a + cast(b, FP32)`` — N parallel, no K reduction.""" + for n0 in pl.parallel(0, HIDDEN, ADD_OUT_CHUNK): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="residual_add"): + a_tile = a[:, n0 : n0 + ADD_OUT_CHUNK] + b_tile = b[:, n0 : n0 + ADD_OUT_CHUNK] + b_f32 = pl.cast(b_tile, target_type=pl.FP32) + sum_tile = pl.add(a_tile, b_f32) + out = pl.assemble(out, sum_tile, [0, n0]) + return out + diff --git a/examples/advanced/cross_class_proj/main.py b/examples/advanced/cross_class_proj/main.py new file mode 100644 index 00000000..aaa8a858 --- /dev/null +++ b/examples/advanced/cross_class_proj/main.py @@ -0,0 +1,115 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Cross-class / cross-file ``@pl.jit.inline`` example. + + output = (x @ w) + hidden_states + +The two stages live on two different classes in two different files: + +* ``proj_lib.Projections.linear`` — tiled matmul (BF16 x BF16 -> FP32) +* ``eltwise_lib.Elementwise.residual_add`` — FP32 tensor + cast(BF16 -> FP32) + +Both are decorated with ``@pl.jit.inline``, so the ``InlineFunctions`` IR +pass splices their bodies into this entry function during compilation — +producing the same lowered IR as a hand-fused single-function kernel, but +with the source split across files for reuse. + +Run with:: + + python examples/advanced/cross_class_proj/main.py -p a2a3sim +""" + +import pypto.language as pl + +from config import BATCH, HIDDEN +from eltwise_lib import Elementwise +from proj_lib import Projections + +linear = Projections.linear +residual_add = Elementwise.residual_add + +class ProjResidual: + @pl.jit + def proj_residual( + x: pl.Tensor[[BATCH, HIDDEN], pl.BF16], + w: pl.Tensor[[HIDDEN, HIDDEN], pl.BF16], + hidden_states: pl.Tensor[[BATCH, HIDDEN], pl.BF16], + out: pl.Out[pl.Tensor[[BATCH, HIDDEN], pl.FP32]], + ): + # Stage 0: linear projection from another class in another file. + proj_out = pl.create_tensor([BATCH, HIDDEN], dtype=pl.FP32) + proj_out = linear(x, w, proj_out) + + # Stage 1: residual add from yet another class in another file. + out = residual_add(proj_out, hidden_states, out) + return out + + +def build_tensor_specs(): + import torch + + from golden import TensorSpec + + scale = HIDDEN ** 0.5 + + def init_x(): + return torch.rand(BATCH, HIDDEN) - 0.5 + + def init_w(): + return (torch.rand(HIDDEN, HIDDEN) - 0.5) / scale + + def init_h(): + return torch.rand(BATCH, HIDDEN) - 0.5 + + return [ + TensorSpec("x", [BATCH, HIDDEN], torch.bfloat16, init_value=init_x), + TensorSpec("w", [HIDDEN, HIDDEN], torch.bfloat16, init_value=init_w), + TensorSpec("hidden_states", [BATCH, HIDDEN], torch.bfloat16, init_value=init_h), + TensorSpec("out", [BATCH, HIDDEN], torch.float32, is_output=True), + ] + + +def golden_proj_residual(tensors): + x_f32 = tensors["x"].float() + w_f32 = tensors["w"].float() + h_f32 = tensors["hidden_states"].float() + tensors["out"][:] = x_f32 @ w_f32 + h_f32 + + +if __name__ == "__main__": + import argparse + + from golden import RunConfig, run_jit + + parser = argparse.ArgumentParser() + parser.add_argument("-p", "--platform", type=str, default="a2a3", + choices=["a2a3", "a2a3sim", "a5", "a5sim"]) + parser.add_argument("-d", "--device", type=int, default=0) + parser.add_argument("--runtime-profiling", action="store_true", default=False) + args = parser.parse_args() + + result = run_jit( + fn=ProjResidual.proj_residual, + specs=build_tensor_specs(), + golden_fn=golden_proj_residual, + config=RunConfig( + rtol=1e-3, + atol=1e-3, + compile=dict(dump_passes=True), + runtime=dict( + platform=args.platform, + device_id=args.device, + runtime_profiling=args.runtime_profiling, + ), + ), + ) + if not result.passed: + if result.error: + print(result.error) + raise SystemExit(1) diff --git a/examples/advanced/cross_class_proj/proj_lib.py b/examples/advanced/cross_class_proj/proj_lib.py new file mode 100644 index 00000000..1cdcdd1b --- /dev/null +++ b/examples/advanced/cross_class_proj/proj_lib.py @@ -0,0 +1,57 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Projection kernels packaged as static methods on a ``Projections`` class. + +Demonstrates organising ``@pl.jit.inline`` helpers behind a class namespace +while keeping them discoverable by ``@pl.jit`` dep auto-discovery. + +Why the module-level alias? +--------------------------- +``_discover_deps`` in ``pypto/python/pypto/jit/decorator.py`` only scans the +entry function's AST for **bare-name** calls (``ast.Call`` whose ``func`` is +an ``ast.Name``). Method-style calls like ``Projections.linear(...)`` are +``ast.Attribute`` and are *not* picked up. Re-exporting the static method as +a module-level binding (``linear = Projections.linear``) lets the entry +function call it as ``linear(...)`` so dep discovery succeeds. +""" + +import pypto.language as pl + +from config import BATCH, HIDDEN, K_PROJ_CHUNK, N_OUT_CHUNK + + +class Projections: + """Linear-projection helpers grouped under a class namespace. + + Each method is decorated with ``@pl.jit.inline`` so its body is spliced + into the caller by the ``InlineFunctions`` IR pass. ``@staticmethod`` is + layered on top so Python returns the underlying ``JITFunction`` when the + attribute is read off the class (``Projections.linear``). + """ + @pl.jit.inline + def linear( + x: pl.Tensor[[BATCH, HIDDEN], pl.BF16], + w: pl.Tensor[[HIDDEN, HIDDEN], pl.BF16], + y: pl.Out[pl.Tensor[[BATCH, HIDDEN], pl.FP32]], + ): + """``y = x @ w`` — N parallel, K reduction pipelined inside each scope.""" + for n0 in pl.parallel(0, HIDDEN, N_OUT_CHUNK): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="linear"): + acc = pl.create_tensor([BATCH, N_OUT_CHUNK], dtype=pl.FP32) + for kb in pl.pipeline(0, HIDDEN // K_PROJ_CHUNK, stage=2): + k0 = kb * K_PROJ_CHUNK + tile_x = x[:, k0 : k0 + K_PROJ_CHUNK] + tile_w = w[k0 : k0 + K_PROJ_CHUNK, n0 : n0 + N_OUT_CHUNK] + if k0 == 0: + acc = pl.matmul(tile_x, tile_w, out_dtype=pl.FP32) + else: + acc = pl.matmul_acc(acc, tile_x, tile_w) + y = pl.assemble(y, acc, [0, n0]) + return y +