Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions examples/advanced/cross_class_proj/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# -----------------------------------------------------------------------------------------------------------
"""Cross-class / cross-file ``@pl.jit.inline`` example package.

Run with::

cd pypto-lib
python examples/advanced/cross_class_proj/main.py -p a2a3sim
"""
19 changes: 19 additions & 0 deletions examples/advanced/cross_class_proj/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# -----------------------------------------------------------------------------------------------------------
"""Shared problem-size and tiling constants for the cross-class example."""

BATCH = 16
HIDDEN = 8192

# Projection tiling
N_OUT_CHUNK = 256 # N tile per parallel core-group
K_PROJ_CHUNK = 128 # K reduction tile inside each scope

# Elementwise tiling
ADD_OUT_CHUNK = 256 # column tile per parallel core-group for residual add
39 changes: 39 additions & 0 deletions examples/advanced/cross_class_proj/eltwise_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# -----------------------------------------------------------------------------------------------------------
"""Elementwise kernels packaged as static methods on an ``Elementwise`` class.

Same dispatch pattern as ``proj_lib.Projections``: the ``@pl.jit.inline`` body
lives inside a class for namespacing, and a module-level alias re-exports it
so ``@pl.jit`` dep auto-discovery (which only matches bare-name calls) finds
the helper from the entry function.
"""

import pypto.language as pl

from config import ADD_OUT_CHUNK, BATCH, HIDDEN


class Elementwise:
"""Elementwise helpers used after the projection step."""
@pl.jit.inline
def residual_add(
a: pl.Tensor[[BATCH, HIDDEN], pl.FP32],
b: pl.Tensor[[BATCH, HIDDEN], pl.BF16],
out: pl.Out[pl.Tensor[[BATCH, HIDDEN], pl.FP32]],
):
"""``out = a + cast(b, FP32)`` — N parallel, no K reduction."""
for n0 in pl.parallel(0, HIDDEN, ADD_OUT_CHUNK):
with pl.at(level=pl.Level.CORE_GROUP, name_hint="residual_add"):
a_tile = a[:, n0 : n0 + ADD_OUT_CHUNK]
b_tile = b[:, n0 : n0 + ADD_OUT_CHUNK]
b_f32 = pl.cast(b_tile, target_type=pl.FP32)
sum_tile = pl.add(a_tile, b_f32)
out = pl.assemble(out, sum_tile, [0, n0])
return out

115 changes: 115 additions & 0 deletions examples/advanced/cross_class_proj/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# -----------------------------------------------------------------------------------------------------------
"""Cross-class / cross-file ``@pl.jit.inline`` example.

output = (x @ w) + hidden_states

The two stages live on two different classes in two different files:

* ``proj_lib.Projections.linear`` — tiled matmul (BF16 x BF16 -> FP32)
* ``eltwise_lib.Elementwise.residual_add`` — FP32 tensor + cast(BF16 -> FP32)

Both are decorated with ``@pl.jit.inline``, so the ``InlineFunctions`` IR
pass splices their bodies into this entry function during compilation —
producing the same lowered IR as a hand-fused single-function kernel, but
with the source split across files for reuse.

Run with::

python examples/advanced/cross_class_proj/main.py -p a2a3sim
"""

import pypto.language as pl

from config import BATCH, HIDDEN
from eltwise_lib import Elementwise
from proj_lib import Projections

linear = Projections.linear
residual_add = Elementwise.residual_add

class ProjResidual:
@pl.jit
def proj_residual(
x: pl.Tensor[[BATCH, HIDDEN], pl.BF16],
w: pl.Tensor[[HIDDEN, HIDDEN], pl.BF16],
hidden_states: pl.Tensor[[BATCH, HIDDEN], pl.BF16],
out: pl.Out[pl.Tensor[[BATCH, HIDDEN], pl.FP32]],
):
# Stage 0: linear projection from another class in another file.
proj_out = pl.create_tensor([BATCH, HIDDEN], dtype=pl.FP32)
proj_out = linear(x, w, proj_out)

# Stage 1: residual add from yet another class in another file.
out = residual_add(proj_out, hidden_states, out)
return out


def build_tensor_specs():
import torch

from golden import TensorSpec

scale = HIDDEN ** 0.5

def init_x():
return torch.rand(BATCH, HIDDEN) - 0.5

def init_w():
return (torch.rand(HIDDEN, HIDDEN) - 0.5) / scale

def init_h():
return torch.rand(BATCH, HIDDEN) - 0.5

return [
TensorSpec("x", [BATCH, HIDDEN], torch.bfloat16, init_value=init_x),
TensorSpec("w", [HIDDEN, HIDDEN], torch.bfloat16, init_value=init_w),
TensorSpec("hidden_states", [BATCH, HIDDEN], torch.bfloat16, init_value=init_h),
TensorSpec("out", [BATCH, HIDDEN], torch.float32, is_output=True),
]


def golden_proj_residual(tensors):
x_f32 = tensors["x"].float()
w_f32 = tensors["w"].float()
h_f32 = tensors["hidden_states"].float()
tensors["out"][:] = x_f32 @ w_f32 + h_f32


if __name__ == "__main__":
import argparse

from golden import RunConfig, run_jit

parser = argparse.ArgumentParser()
parser.add_argument("-p", "--platform", type=str, default="a2a3",
choices=["a2a3", "a2a3sim", "a5", "a5sim"])
parser.add_argument("-d", "--device", type=int, default=0)
parser.add_argument("--runtime-profiling", action="store_true", default=False)
args = parser.parse_args()

result = run_jit(
fn=ProjResidual.proj_residual,
specs=build_tensor_specs(),
golden_fn=golden_proj_residual,
config=RunConfig(
rtol=1e-3,
atol=1e-3,
compile=dict(dump_passes=True),
runtime=dict(
platform=args.platform,
device_id=args.device,
runtime_profiling=args.runtime_profiling,
),
),
)
if not result.passed:
if result.error:
print(result.error)
raise SystemExit(1)
57 changes: 57 additions & 0 deletions examples/advanced/cross_class_proj/proj_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# -----------------------------------------------------------------------------------------------------------
"""Projection kernels packaged as static methods on a ``Projections`` class.

Demonstrates organising ``@pl.jit.inline`` helpers behind a class namespace
while keeping them discoverable by ``@pl.jit`` dep auto-discovery.

Why the module-level alias?
---------------------------
``_discover_deps`` in ``pypto/python/pypto/jit/decorator.py`` only scans the
entry function's AST for **bare-name** calls (``ast.Call`` whose ``func`` is
an ``ast.Name``). Method-style calls like ``Projections.linear(...)`` are
``ast.Attribute`` and are *not* picked up. Re-exporting the static method as
a module-level binding (``linear = Projections.linear``) lets the entry
function call it as ``linear(...)`` so dep discovery succeeds.
"""

import pypto.language as pl

from config import BATCH, HIDDEN, K_PROJ_CHUNK, N_OUT_CHUNK


class Projections:
"""Linear-projection helpers grouped under a class namespace.

Each method is decorated with ``@pl.jit.inline`` so its body is spliced
into the caller by the ``InlineFunctions`` IR pass. ``@staticmethod`` is
layered on top so Python returns the underlying ``JITFunction`` when the
attribute is read off the class (``Projections.linear``).
"""
@pl.jit.inline
def linear(
x: pl.Tensor[[BATCH, HIDDEN], pl.BF16],
w: pl.Tensor[[HIDDEN, HIDDEN], pl.BF16],
y: pl.Out[pl.Tensor[[BATCH, HIDDEN], pl.FP32]],
):
"""``y = x @ w`` — N parallel, K reduction pipelined inside each scope."""
for n0 in pl.parallel(0, HIDDEN, N_OUT_CHUNK):
with pl.at(level=pl.Level.CORE_GROUP, name_hint="linear"):
acc = pl.create_tensor([BATCH, N_OUT_CHUNK], dtype=pl.FP32)
for kb in pl.pipeline(0, HIDDEN // K_PROJ_CHUNK, stage=2):
Comment on lines +46 to +47
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The allocation of acc at line 48 is redundant because acc is immediately reassigned in the first iteration of the kb loop (when k0 == 0) at line 54. Removing this unnecessary pl.create_tensor call saves resources and simplifies the code.

Suggested change
acc = pl.create_tensor([BATCH, N_OUT_CHUNK], dtype=pl.FP32)
for kb in pl.pipeline(0, HIDDEN // K_PROJ_CHUNK, stage=2):
for kb in pl.pipeline(0, HIDDEN // K_PROJ_CHUNK, stage=2):

k0 = kb * K_PROJ_CHUNK
Comment on lines +47 to +48
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Guard against silent K-tail truncation in the reduction loop.

At Line 49, floor-division (HIDDEN // K_PROJ_CHUNK) drops any remainder, which would silently skip part of K if constants change. Add a fail-fast check before the loop.

Proposed fix
 class Projections:
@@
     def linear(
         x: pl.Tensor[[BATCH, HIDDEN], pl.BF16],
         w: pl.Tensor[[HIDDEN, HIDDEN], pl.BF16],
         y: pl.Out[pl.Tensor[[BATCH, HIDDEN], pl.FP32]],
     ):
         """``y = x @ w`` — N parallel, K reduction pipelined inside each scope."""
+        if HIDDEN % K_PROJ_CHUNK != 0:
+            raise ValueError("HIDDEN must be divisible by K_PROJ_CHUNK to avoid dropping K tail.")
         for n0 in pl.parallel(0, HIDDEN, N_OUT_CHUNK):
             with pl.at(level=pl.Level.CORE_GROUP, name_hint="linear"):
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/advanced/cross_class_proj/proj_lib.py` around lines 49 - 50, The
loop using pl.pipeline(0, HIDDEN // K_PROJ_CHUNK, stage=2) silently drops any
remainder of HIDDEN divided by K_PROJ_CHUNK; add a fail-fast check before that
loop to ensure HIDDEN % K_PROJ_CHUNK == 0 (or raise a clear error) so the
reduction over k0 = kb * K_PROJ_CHUNK doesn't skip tail elements; locate the
check near the use of HIDDEN, K_PROJ_CHUNK and pl.pipeline in proj_lib.py and
raise a ValueError (or assert) with a descriptive message if the remainder is
non-zero.

tile_x = x[:, k0 : k0 + K_PROJ_CHUNK]
tile_w = w[k0 : k0 + K_PROJ_CHUNK, n0 : n0 + N_OUT_CHUNK]
if k0 == 0:
acc = pl.matmul(tile_x, tile_w, out_dtype=pl.FP32)
else:
acc = pl.matmul_acc(acc, tile_x, tile_w)
y = pl.assemble(y, acc, [0, n0])
return y

Loading