diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 558122479..00560e763 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -199,7 +199,7 @@ jobs: timeout 60 git clone https://github.com/hw-native-sys/pto-isa.git $GITHUB_WORKSPACE/pto-isa \ || { rm -rf $GITHUB_WORKSPACE/pto-isa; timeout 300 git clone https://gitcode.com/luohuan40/pto-isa.git $GITHUB_WORKSPACE/pto-isa; } cd $GITHUB_WORKSPACE/pto-isa - git checkout 2c607938 + git checkout 8bd3ac8f30bd237f9eaf12c142002a5cc0edb143 - name: Install simpler run: | @@ -219,7 +219,7 @@ jobs: # runs here despite the tests/st/codegen ignore: its numeric golden # compare is unstable on the CPU container's torch/BLAS build (see the # codegen-tests job), so it needs this environment. - run: pytest tests/st tests/st/codegen/torch/test_torch_codegen_paged_attention.py -v --device="$DEVICE_RANGE" --precompile-workers=128 --pto-isa-commit=2c607938 --ignore=tests/st/distributed --ignore=tests/st/codegen + run: pytest tests/st tests/st/codegen/torch/test_torch_codegen_paged_attention.py -v --device="$DEVICE_RANGE" --precompile-workers=128 --pto-isa-commit=8bd3ac8f30bd237f9eaf12c142002a5cc0edb143 --ignore=tests/st/distributed --ignore=tests/st/codegen - name: Test multi-orch L2 (isolated pytest invocation) # A ``Worker(level=2)`` device session currently leaves runtime @@ -228,12 +228,12 @@ jobs: # isolation bug). Running this file in its own pytest process # keeps the L2 leak from poisoning ``test_l3_distributed``. timeout-minutes: 3 - run: pytest tests/st/distributed/test_l2_multi_orch.py -v --device="$DEVICE_RANGE" --pto-isa-commit=2c607938 + run: pytest tests/st/distributed/test_l2_multi_orch.py -v --device="$DEVICE_RANGE" --pto-isa-commit=8bd3ac8f30bd237f9eaf12c142002a5cc0edb143 - name: Test swimlane output run: | pytest tests/st/runtime/framework_and_models/test_perf_swimlane.py \ - -v --device="$DEVICE_ID" --platform=a2a3 --enable-l2-swimlane --forked --pto-isa-commit=2c607938 + -v --device="$DEVICE_ID" --platform=a2a3 --enable-l2-swimlane --forked --pto-isa-commit=8bd3ac8f30bd237f9eaf12c142002a5cc0edb143 dist-system-tests: # No container: HCCL needs host-only env (HCCL_LOGIC_SUPERPOD_ID, plugin @@ -409,7 +409,7 @@ jobs: timeout 60 git clone https://github.com/hw-native-sys/pto-isa.git $GITHUB_WORKSPACE/pto-isa \ || { rm -rf $GITHUB_WORKSPACE/pto-isa; timeout 300 git clone https://gitcode.com/luohuan40/pto-isa.git $GITHUB_WORKSPACE/pto-isa; } cd $GITHUB_WORKSPACE/pto-isa - git checkout 2c607938 + git checkout 8bd3ac8f30bd237f9eaf12c142002a5cc0edb143 - name: Install simpler run: | @@ -490,10 +490,10 @@ jobs: # against the runtime's bundled pto-isa, which still clones from the # stale PTO-ISA/pto-isa mirror (lacks pto::Coalesce) until the runtime # submodule is bumped past simpler #806. Re-enable once that lands. - run: pytest tests/st/runtime/ops/test_assemble.py tests/st/runtime/ops/test_mscatter.py tests/st/runtime/framework_and_models/test_qwen3_decode_scope3_mixed.py tests/st/runtime/control_flow/test_dyn_orch_shape.py::TestDynOrchShapeOperations::test_dyn_orch_paged_attention -v --platform=a5sim --forked --pto-isa-commit=2c607938 -k "not TestMscatter" + run: pytest tests/st/runtime/ops/test_assemble.py tests/st/runtime/ops/test_mscatter.py tests/st/runtime/framework_and_models/test_qwen3_decode_scope3_mixed.py tests/st/runtime/control_flow/test_dyn_orch_shape.py::TestDynOrchShapeOperations::test_dyn_orch_paged_attention -v --platform=a5sim --forked --pto-isa-commit=8bd3ac8f30bd237f9eaf12c142002a5cc0edb143 -k "not TestMscatter" - name: Test A5 cross-core system tests (simulator) - run: pytest tests/st/runtime/cross_core/test_cross_core.py -v --forked --platform=a5sim --pto-isa-commit=2c607938 + run: pytest tests/st/runtime/cross_core/test_cross_core.py -v --forked --platform=a5sim --pto-isa-commit=8bd3ac8f30bd237f9eaf12c142002a5cc0edb143 - name: Test A2A3 cross-core system tests (simulator) - run: pytest tests/st/runtime/cross_core/test_cross_core.py -v --forked --platform=a2a3sim --pto-isa-commit=2c607938 + run: pytest tests/st/runtime/cross_core/test_cross_core.py -v --forked --platform=a2a3sim --pto-isa-commit=8bd3ac8f30bd237f9eaf12c142002a5cc0edb143 diff --git a/python/pypto/language/parser/type_resolver.py b/python/pypto/language/parser/type_resolver.py index 80018092a..a6d14e65f 100644 --- a/python/pypto/language/parser/type_resolver.py +++ b/python/pypto/language/parser/type_resolver.py @@ -1309,8 +1309,18 @@ def _resolve_tileview( # noqa: PLR0912 # fields matching the memory-space-aware implicit defaults; the parser must # recover them from the same rules — regardless of whether memory_space is # present, since shape-derived defaults (e.g. col_major for [N,1]) also apply. + # + # The basis MUST be the full tile shape, not valid_shape: the printer elides + # against ``GetImplicitTileView(tile_type.shape_, ...)`` (the physical tile + # shape), so the parser has to infer from the same shape. Using valid_shape + # here desynchronized the two for packed-mask tiles — e.g. a cmp/cmps result + # with physical shape [16, 8] but valid_shape [16, 1]: the printer omits the + # (row_major) blayout, while valid_shape's cols==1 made the parser fill + # col_major, so the print->parse roundtrip failed with "TileView blayout + # mismatch" (#1498). Prefer tile_shape, falling back to valid_shape only when + # the physical shape is unavailable. impl_blayout, impl_slayout, impl_fractal = _implicit_tile_view_defaults( - valid_shape if valid_shape else (tile_shape or []), memory_space + tile_shape if tile_shape else (valid_shape or []), memory_space ) return ir.TileView( valid_shape=valid_shape if valid_shape is not None else [], diff --git a/src/backend/common/pto_ops_common.cpp b/src/backend/common/pto_ops_common.cpp index b8fbc7f46..c115ef3d4 100644 --- a/src/backend/common/pto_ops_common.cpp +++ b/src/backend/common/pto_ops_common.cpp @@ -973,7 +973,14 @@ static std::string MakeScatterCodegenPTO(const CallPtr& op, codegen::CodegenBase std::ostringstream oss; oss << "pto.tscatter ins(" << src << ", " << idx; // Emit the type clause only when both annotations are present; printing one - // alone would produce malformed PTOAS (": , idx" or ": src, "). + // alone would produce malformed PTOAS (": , idx" or ": src, "). The two + // operands are typed tiles produced by the same lowering, so they should + // either both carry an annotation or (in untyped contexts) both lack one — a + // one-sided annotation signals a real codegen bug, not a valid input. + INTERNAL_CHECK_SPAN(src_type.empty() == idx_type.empty(), op->span_) + << "Internal error: tile.scatter src/indexes type annotations must both be present or both " + "absent, got src_type='" + << src_type << "', idx_type='" << idx_type << "'"; if (!src_type.empty() && !idx_type.empty()) { oss << " : " << src_type << ", " << idx_type; } @@ -988,16 +995,17 @@ static std::string MakeScatterCodegenPTO(const CallPtr& op, codegen::CodegenBase } // Helper for tile.scatter_mask (TSCATTER mask form, DPS): -// pto.tscatter ins(%src : src_ty) outs(%dst : dst_ty) -// {maskPattern = #pto.mask_pattern} +// pto.tscatter ins(%src, {maskPattern = #pto.mask_pattern} : src_ty) +// outs(%dst : dst_ty) // -// Unlike pto.tgather's mask form (which carries maskPattern inside ins()), -// pto.tscatter only accepts SSA operands in ins() — the maskPattern must be a -// trailing op attribute dict (same placement as gather_compare's cmpMode). +// The maskPattern rides *inside* ins() right after the src operand, exactly +// like pto.tgather's mask form — PTOAS parses ins() as "src, attr-dict : +// type" and rejects a bare ins(%src ...) ("expected ',' after src operand"). +// The type annotation follows the attr dict, still inside ins(). // // IR surface: 2-input op (dst, src) + mask_pattern attr; dst aliased via -// set_output_reuses_input(0). Mask form is targeted at A3 / CPU-sim style -// backends; A5 rejects it on the PTOAS side. +// set_output_reuses_input(0). Mask form is targeted at A2/A3 backends; A5 +// (Ascend950) rejects it on the PTOAS side. static std::string MakeScatterMaskCodegenPTO(const CallPtr& op, codegen::CodegenBase& codegen_base) { auto& codegen = dynamic_cast(codegen_base); CHECK(op->args_.size() == 2) << "tile.scatter_mask requires 2 arguments (dst, src), but got " @@ -1021,7 +1029,10 @@ static std::string MakeScatterMaskCodegenPTO(const CallPtr& op, codegen::Codegen << ", input=" << input_ssa; std::ostringstream oss; - oss << "pto.tscatter ins(" << src; + // maskPattern rides inside ins() after src, then the type annotation: + // pto.tscatter ins(%src, {maskPattern = #pto.mask_pattern} : src_ty) outs(%dst : dst_ty) + oss << "pto.tscatter ins(" << src << ", {maskPattern = #pto.mask_pattern<" << mask_patterns.at(pattern) + << ">}"; if (!src_type.empty()) { oss << " : " << src_type; } @@ -1029,7 +1040,7 @@ static std::string MakeScatterMaskCodegenPTO(const CallPtr& op, codegen::Codegen if (!dst_type.empty()) { oss << " : " << dst_type; } - oss << ") {maskPattern = #pto.mask_pattern<" << mask_patterns.at(pattern) << ">}"; + oss << ")"; codegen.Emit(oss.str()); return ""; diff --git a/src/ir/op/tile_ops/scatter.cpp b/src/ir/op/tile_ops/scatter.cpp index 20053724c..5ef65dac2 100644 --- a/src/ir/op/tile_ops/scatter.cpp +++ b/src/ir/op/tile_ops/scatter.cpp @@ -19,6 +19,13 @@ * * Both ops are DPS — `dst` is the in/out buffer; the IR result aliases `dst` * via `set_output_reuses_input(...)`. There is no compare form for scatter. + * + * Duplicate-index ordering: when two index entries map to the same destination + * slot, pto.tscatter resolves the collision in ascending element order (the + * later/higher-index write wins), matching torch `scatter_`'s last-wins + * semantics along the scan axis. Callers that build flat indices + * (ConvertTensorToTileOps) and the ST reference both rely on this order; it is + * a pto.tscatter ABI guarantee, not a PyPTO-side choice. */ #include diff --git a/src/ir/transforms/op_conversion_registry.cpp b/src/ir/transforms/op_conversion_registry.cpp index c26dac916..54390319a 100644 --- a/src/ir/transforms/op_conversion_registry.cpp +++ b/src/ir/transforms/op_conversion_registry.cpp @@ -1432,7 +1432,10 @@ void OpConversionRegistry::RegisterGatherOps() { // select blend. The surrounding pass wraps the tile result in a tile.store to // the output tensor param. // -// tensor.scatter_mask: same idea, simple (input, dst) → (dst, src) re-wire. +// tensor.scatter_mask: same idea — pto.tscatter (mask form) zero-fills the whole +// dst before writing the selected columns, so it does not preserve dst either. +// We reconstruct DPS preserve with the same zeroed-scatter + mask + select +// blend, which also makes chaining two patterns into one dst sound. // ============================================================================ void OpConversionRegistry::RegisterScatterOps() { @@ -1476,6 +1479,10 @@ void OpConversionRegistry::RegisterScatterOps() { // flat_idx = index + row_base (row-broadcast add) auto& op_reg = OpRegistry::GetInstance(); auto src_rows = As(src_tile->shape_[0]); + // `cols` is the flat-layout column count of the scattered destination. + // input/output and the inner tile.scatter dst all share this width (S), + // so reading it off `input` is equivalent to the dst column count used + // in the `i * dst_cols` flat-index formula above. auto dst_cols_c = As(input_tile->shape_[1]); CHECK(src_rows && dst_cols_c) << "tensor.scatter conversion requires static src rows and dst cols for index expansion"; @@ -1486,6 +1493,26 @@ void OpConversionRegistry::RegisterScatterOps() { // for 4-byte dst, INT16 for 2/1-byte dst). const DataType idx_dtype = idx_tile->dtype_; + // INT16 flat-index range guard. For 2-byte element dtypes the flattened + // destination indices are INT16, whose largest representable value is + // 32767. The biggest index this lowering produces is + // (n-1)*cols + (cols-1) == n*cols - 1, so n*cols must stay <= 32768. + // 4-byte dtypes use INT32 indices and are effectively unbounded here. + // Without this check an oversized tile would silently overflow INT16 and + // scatter to wrong addresses instead of failing loudly. + if (idx_dtype == DataType::INT16) { + // Bound via division so the product never overflows int64_t: rows is + // capped so rows*cols stays <= 32768. cols is always > 0 here (a + // 2-byte tile has at least one column), but guard against 0 anyway. + const int64_t kMaxFlat = 32768; + const int64_t max_rows = cols == 0 ? kMaxFlat : kMaxFlat / cols; + CHECK(n <= max_rows) << "tensor.scatter with element dtype " << input_tile->dtype_.ToString() + << " uses INT16 flattened indices, but the destination is too large: rows(" + << n << ") * cols(" << cols + << ") exceeds the INT16 index range (max flat index 32767, rows <= " + << max_rows << "). Use a smaller tile or split the scatter into chunks."; + } + auto make_idx = [&](int64_t v) -> ExprPtr { return std::make_shared(v, DataType::INDEX, span); }; @@ -1611,10 +1638,79 @@ void OpConversionRegistry::RegisterScatterOps() { CHECK(args.size() == 2) << "tensor.scatter_mask conversion expects 2 args (input, dst), got " << args.size(); auto& op_reg = OpRegistry::GetInstance(); - // tile.scatter_mask signature: (dst, src) + mask_pattern attr. The - // converter's args[1] (dst) maps to dst, args[0] (input) to src. - auto tile_call = op_reg.Create("tile.scatter_mask", {args[1], args[0]}, kwargs, span); - return ConversionResult{tile_call}; + auto input_tile = As(args[0]->GetType()); + auto dst_tile = As(args[1]->GetType()); + CHECK(input_tile && dst_tile) + << "tensor.scatter_mask conversion: input/dst must be Vec tiles after bridge"; + + // pto.tscatter (mask form) zero-fills the entire dst tile before writing + // the mask-selected columns (TScatterMaskImpl calls InitUBBuffer), so it + // does NOT preserve dst's unselected columns — they read back as 0. To + // honour the DPS preserve contract (and make chaining two patterns into + // one dst sound), reconstruct preserve on the PyPTO side with the same + // zeroed-scatter + mask + select blend as the index form: + // + // scattered = scatter_mask(zeros, input) # input @selected, 0 @unselected + // mask = scatter_mask(zeros_m, ones) # 1 @selected, 0 @unselected + // pred = (mask != 0) # packed predicate + // out = sel(pred, scattered, dst) # selected→scattered, else→dst + auto in_rows = As(input_tile->shape_[0]); + auto in_cols = As(input_tile->shape_[1]); + auto dst_cols_c = As(dst_tile->shape_[1]); + CHECK(in_rows && in_cols && dst_cols_c) + << "tensor.scatter_mask conversion requires static shapes for the preserve blend"; + const int64_t b = in_rows->value_; + const int64_t c = in_cols->value_; + const int64_t dst_cols = dst_cols_c->value_; + const DataType dt = dst_tile->dtype_; + + // Mask blend dtype: a compare-friendly type within dst's element size + // (bf16 → f16) so tile.cmps is well-defined for any supported dtype. + const int dt_bytes = static_cast(dt.GetBit()) / 8; + const DataType mask_dt = (dt_bytes == 4) ? DataType(DataType::FP32) + : (dt_bytes == 2) ? DataType(DataType::FP16) + : DataType(DataType::INT8); + + auto make_idx = [&](int64_t v) -> ExprPtr { + return std::make_shared(v, DataType::INDEX, span); + }; + std::vector prologue; + auto emit = [&](const std::string& op_name, const std::vector& op_args, + const std::vector>& op_kwargs, + const std::string& name) -> VarPtr { + auto call = op_kwargs.empty() ? op_reg.Create(op_name, op_args, span) + : op_reg.Create(op_name, op_args, op_kwargs, span); + auto var = std::make_shared(name, call->GetType(), span); + prologue.push_back(std::make_shared(var, call, span)); + return var; + }; + auto make_full = [&](const DataType& fdt, int64_t rows, int64_t cols_, double v, + const std::string& name) -> VarPtr { + ExprPtr val = fdt.IsFloat() + ? ExprPtr(std::make_shared(v, fdt, span)) + : ExprPtr(std::make_shared(static_cast(v), fdt, span)); + return emit("tile.full", {MakeShapeTuple({make_idx(rows), make_idx(cols_)}, span), val}, + {{"dtype", fdt}}, name); + }; + + // scattered = input written into the mask-selected columns of a zeroed dst. + auto values_zero = make_full(dt, b, dst_cols, 0.0, "scatter_mask_values_zero"); + auto scattered = emit("tile.scatter_mask", {values_zero, args[0]}, kwargs, "scatter_mask_values"); + // mask = ones written into the same selected columns of a zeroed base. + auto mask_zero = make_full(mask_dt, b, dst_cols, 0.0, "scatter_mask_mask_zero"); + auto ones_src = make_full(mask_dt, b, c, 1.0, "scatter_mask_ones"); + auto mask = emit("tile.scatter_mask", {mask_zero, ones_src}, kwargs, "scatter_mask_mask"); + // pred = (mask != 0) → packed predicate (NE = cmp_type 1). + ExprPtr zero_scalar = mask_dt.IsFloat() ? ExprPtr(std::make_shared(0.0, mask_dt, span)) + : ExprPtr(std::make_shared(0, mask_dt, span)); + auto pred = emit("tile.cmps", {mask, zero_scalar}, {{"cmp_type", 1}}, "scatter_mask_pred"); + // tmp = TSEL scratch tile (UINT8 [1, 32]). + auto tmp = emit("tile.create", {MakeShapeTuple({make_idx(1), make_idx(32)}, span)}, + {{"dtype", DataType(DataType::UINT8)}, {"target_memory", MemorySpace::Vec}}, + "scatter_mask_sel_tmp"); + // out = sel(pred, scattered, dst, tmp): scattered @selected, dst @unselected. + auto out_call = op_reg.Create("tile.sel", {pred, scattered, args[1], tmp}, span); + return ConversionResult{std::move(prologue), out_call}; }, std::move(scatter_mask_input_reqs)); } diff --git a/tests/st/runtime/test_scatter.py b/tests/st/runtime/ops/test_scatter.py similarity index 61% rename from tests/st/runtime/test_scatter.py rename to tests/st/runtime/ops/test_scatter.py index 395ddbfa7..6870185c3 100644 --- a/tests/st/runtime/test_scatter.py +++ b/tests/st/runtime/ops/test_scatter.py @@ -53,6 +53,21 @@ A separate FP32 case feeds **repeated** indices with distinct values to pin the ascending-k last-wins ordering (the round-trip version hid this by writing equal values to repeated targets). + +**Mask form (A2/A3).** ``TestScatterMaskForm`` covers ``tensor.scatter_mask`` +(``mask_pattern=`` + ``dst``), the column-wise inverse of the mask-form +gather: each compact ``input`` row is written into the mask-selected columns of +the wider ``dst`` (``dst.cols == input.cols * stride``). The form runs on the +A2/A3 backend (``BackendType.Ascend910B``); A5 (``Ascend950``) rejects it, so +those cases are pinned to A2/A3 only. Column selection mirrors gather: P0101 +hits even columns (``0::2``), P1010 hits odd columns (``1::2``). + +The raw ``pto.tscatter`` mask instruction zero-fills the entire ``dst`` before +writing the selected columns, so the lowering reconstructs DPS preserve with the +same zeroed-scatter + mask + select blend as the index form. These cases pin +that with a **non-zero sentinel** ``dst`` (unselected columns must survive), and +the chain case writes two patterns into one ``dst`` (P0101 then P1010) so the +second scatter must preserve the first's writes — the RoPE even/odd reassembly. """ from typing import Any @@ -138,6 +153,42 @@ def _scatter_specs( ] +def _scatter_mask_specs(b: int, c: int, stride: int, dt: DataType, torch_dt: torch.dtype) -> list[TensorSpec]: + """Build the (inp, dst, output) TensorSpecs for a mask-form scatter case. + + ``inp`` is ``[B, C]`` of distinct positive values; ``dst`` is ``[B, C*stride]`` + pre-filled with a **negative sentinel** (disjoint from ``inp``) rather than + zeros, so the case actually distinguishes preserve from zero-fill: the + mask-selected columns must become ``inp`` and the unselected columns must + keep the sentinel. A zero ``dst`` could not tell the two apart. + """ + dst_cols = c * stride + return [ + TensorSpec("inp", [b, c], dt, init_value=lambda: _make_values(b, c, torch_dt)), + TensorSpec("dst", [b, dst_cols], dt, init_value=lambda: _make_base(b, dst_cols, torch_dt)), + TensorSpec("output", [b, dst_cols], dt, is_output=True), + ] + + +def _scatter_mask_chain_specs(b: int, c: int, dt: DataType, torch_dt: torch.dtype) -> list[TensorSpec]: + """Build the (even, odd, dst, output) TensorSpecs for the chained mask-scatter case. + + Two compact ``[B, C]`` inputs are interleaved into one ``[B, 2C]`` dst by + chaining two mask scatters: ``even`` into the even columns (P0101) and + ``odd`` into the odd columns (P1010). ``even`` and ``odd`` hold disjoint + positive ranges so a swapped pattern or a clobbered column is caught. ``dst`` + starts zeroed; a correct chain leaves ``dst[:, 0::2] = even`` and + ``dst[:, 1::2] = odd`` (the second scatter must preserve the first's writes). + """ + dst_cols = 2 * c + return [ + TensorSpec("even", [b, c], dt, init_value=lambda: _make_values(b, c, torch_dt)), + TensorSpec("odd", [b, c], dt, init_value=lambda: _make_values(b, c, torch_dt) + b * c), + TensorSpec("dst", [b, dst_cols], dt, init_value=lambda: torch.zeros(b, dst_cols, dtype=torch_dt)), + TensorSpec("output", [b, dst_cols], dt, is_output=True), + ] + + # --- Programs (one per element type; shapes satisfy the 32-byte row alignment) --- @@ -256,6 +307,77 @@ def main( return output +# --- Mask-form programs (A2/A3 only; A5/Ascend950 rejects pto.tscatter mask form) --- + + +@pl.program +class ScatterMaskP0101Program: + """Mask-form scatter (P0101, A2/A3): write ``inp`` into even columns of ``dst``. + + ``inp [8, 8] FP32`` expands into ``dst [8, 16] FP32`` (stride 2): the inverse + of the P0101 mask gather, so ``dst[:, 0::2] = inp``. + """ + + @pl.function(type=pl.FunctionType.Opaque) + def main( + self, + inp: pl.Tensor[[8, 8], pl.FP32], + dst: pl.Tensor[[8, 16], pl.FP32], + output: pl.Out[pl.Tensor[[8, 16], pl.FP32]], + ) -> pl.Tensor[[8, 16], pl.FP32]: + with pl.at(level=pl.Level.CORE_GROUP): + out = pl.tensor.scatter(inp, mask_pattern=pl.tile.MaskPattern.P0101, dst=dst) + output = pl.assemble(output, out, [0, 0]) + return output + + +@pl.program +class ScatterMaskP1010Program: + """Mask-form scatter (P1010, A2/A3): write ``inp`` into odd columns of ``dst``. + + ``inp [8, 8] FP32`` expands into ``dst [8, 16] FP32`` (stride 2): the inverse + of the P1010 mask gather, so ``dst[:, 1::2] = inp``. + """ + + @pl.function(type=pl.FunctionType.Opaque) + def main( + self, + inp: pl.Tensor[[8, 8], pl.FP32], + dst: pl.Tensor[[8, 16], pl.FP32], + output: pl.Out[pl.Tensor[[8, 16], pl.FP32]], + ) -> pl.Tensor[[8, 16], pl.FP32]: + with pl.at(level=pl.Level.CORE_GROUP): + out = pl.tensor.scatter(inp, mask_pattern=pl.tile.MaskPattern.P1010, dst=dst) + output = pl.assemble(output, out, [0, 0]) + return output + + +@pl.program +class ScatterMaskChainProgram: + """Chained mask-form scatter (P0101 then P1010, A2/A3): RoPE even/odd reassembly. + + Writes ``even`` into the even columns (P0101) and ``odd`` into the odd columns + (P1010) of a single ``dst``, chaining two mask scatters into one buffer — the + inverse of splitting a RoPE head into even/odd halves. Validates that the + second scatter preserves the first's writes, so ``dst[:, 0::2] = even`` and + ``dst[:, 1::2] = odd``. + """ + + @pl.function(type=pl.FunctionType.Opaque) + def main( + self, + even: pl.Tensor[[16, 32], pl.FP32], + odd: pl.Tensor[[16, 32], pl.FP32], + dst: pl.Tensor[[16, 64], pl.FP32], + output: pl.Out[pl.Tensor[[16, 64], pl.FP32]], + ) -> pl.Tensor[[16, 64], pl.FP32]: + with pl.at(level=pl.Level.CORE_GROUP): + out = pl.tensor.scatter(even, mask_pattern=pl.tile.MaskPattern.P0101, dst=dst) + out = pl.tensor.scatter(odd, mask_pattern=pl.tile.MaskPattern.P1010, dst=out) + output = pl.assemble(output, out, [0, 0]) + return output + + # --- Test cases --- @@ -358,6 +480,85 @@ def get_program(self) -> Any: return Scatter1RowProgram +class _ScatterMaskBaseTestCase(PTOTestCase): + """Base for mask-form scatter cases. Pinned to A2/A3 (Ascend910B). + + Subclasses set ``_start`` (0 for P0101, 1 for P1010) and ``_stride`` (2); + ``compute_expected`` writes ``inp`` into the ``[start::stride]`` columns of + ``dst`` while **preserving** ``dst``'s other (sentinel) columns — so the case + fails if the lowering zero-fills the unselected columns instead. + """ + + __test__ = False + _start: int = 0 + _stride: int = 2 + + def get_strategy(self) -> OptimizationStrategy: + return OptimizationStrategy.Default + + def get_backend_type(self) -> BackendType: + # Mask-form pto.tscatter is an A2/A3 feature; A5 (Ascend950) rejects it. + return BackendType.Ascend910B + + def compute_expected(self, tensors, params=None): + # Preserve dst's unselected (sentinel) columns; write inp into the + # mask-selected columns. Discriminates preserve from zero-fill. + out = tensors["dst"].clone() + out[:, self._start :: self._stride] = tensors["inp"] + tensors["output"][:] = out + + +class ScatterMaskP0101TestCase(_ScatterMaskBaseTestCase): + _start = 0 # P0101 selects even columns + + def get_name(self) -> str: + return "scatter_mask_p0101" + + def define_tensors(self) -> list[TensorSpec]: + return _scatter_mask_specs(8, 8, 2, DataType.FP32, torch.float32) + + def get_program(self) -> Any: + return ScatterMaskP0101Program + + +class ScatterMaskP1010TestCase(_ScatterMaskBaseTestCase): + _start = 1 # P1010 selects odd columns + + def get_name(self) -> str: + return "scatter_mask_p1010" + + def define_tensors(self) -> list[TensorSpec]: + return _scatter_mask_specs(8, 8, 2, DataType.FP32, torch.float32) + + def get_program(self) -> Any: + return ScatterMaskP1010Program + + +class ScatterMaskChainTestCase(_ScatterMaskBaseTestCase): + """Chain P0101 then P1010 into one dst (RoPE even/odd reassembly). + + Unlike the single-pattern cases, this writes two compact inputs into one + interleaved ``dst`` via two chained scatters, pinning that the second + pattern's scatter preserves the first's writes. + """ + + def get_name(self) -> str: + return "scatter_mask_chain" + + def define_tensors(self) -> list[TensorSpec]: + return _scatter_mask_chain_specs(16, 32, DataType.FP32, torch.float32) + + def get_program(self) -> Any: + return ScatterMaskChainProgram + + def compute_expected(self, tensors, params=None): + b, dst_cols = tensors["output"].shape + out = torch.zeros(b, dst_cols, dtype=tensors["even"].dtype) + out[:, 0::2] = tensors["even"] # P0101 → even columns + out[:, 1::2] = tensors["odd"] # P1010 → odd columns + tensors["output"][:] = out + + # --- Tests --- @@ -400,5 +601,31 @@ def test_scatter_1row(self, test_runner, platform): assert result.passed, f"Test failed: {result.error}" +class TestScatterMaskForm: + """Mask-form row scatter — A2/A3 only (A5/Ascend950 rejects the mask form). + + Each compact ``inp`` row is written into the mask-selected columns of the + wider ``dst`` (``dst.cols == inp.cols * stride``); column selection mirrors + the mask gather (P0101 → even, P1010 → odd). The chain case writes two inputs + into one ``dst`` (P0101 then P1010) to pin that the second scatter preserves + the first's writes — the RoPE even/odd reassembly tail. + """ + + @pytest.mark.parametrize("platform", PLATFORMS) + def test_scatter_mask_p0101(self, test_runner, platform): + result = test_runner.run(ScatterMaskP0101TestCase(platform=platform)) + assert result.passed, f"Test failed: {result.error}" + + @pytest.mark.parametrize("platform", PLATFORMS) + def test_scatter_mask_p1010(self, test_runner, platform): + result = test_runner.run(ScatterMaskP1010TestCase(platform=platform)) + assert result.passed, f"Test failed: {result.error}" + + @pytest.mark.parametrize("platform", PLATFORMS) + def test_scatter_mask_chain(self, test_runner, platform): + result = test_runner.run(ScatterMaskChainTestCase(platform=platform)) + assert result.passed, f"Test failed: {result.error}" + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/ut/codegen/test_pto_codegen_ops.py b/tests/ut/codegen/test_pto_codegen_ops.py index 7572ff60c..7310dc063 100644 --- a/tests/ut/codegen/test_pto_codegen_ops.py +++ b/tests/ut/codegen/test_pto_codegen_ops.py @@ -1940,6 +1940,12 @@ def kernel( assert "ins(" in line and "outs(" in line, ( f"pto.tscatter mask form must use the ins(...) outs(...) DPS form, got:\n{line}" ) + # maskPattern must ride inside ins() after src (like pto.tgather), not as + # a trailing attr after outs() — PTOAS rejects a bare ins(%src ...) with + # "expected ',' after src operand". + assert line.index("maskPattern") < line.index("outs("), ( + f"maskPattern must appear inside ins(...) before outs(...), got:\n{line}" + ) if __name__ == "__main__": diff --git a/tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py b/tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py index f949733dd..b37304b8d 100644 --- a/tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py +++ b/tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py @@ -2430,13 +2430,6 @@ def test_gather_mask_conversion(self): class TestConvertScatterOp: """Test conversion of tensor.scatter (rank-2 dim=-1 MVP) and tensor.scatter_mask.""" - @pytest.mark.skip( - reason="Blocked by a pre-existing cmp/cmps round-trip gap: the scatter preserve " - "blend emits tile.cmps, whose packed-mask result TileView loses its blayout on " - "print->parse, so the autouse RoundtripInstrument fails structural equality " - "(same failure as any tensor.cmp conversion). Assertions below are correct and " - "ready once the packed-mask TileView blayout round-trips." - ) def test_scatter_conversion(self): """tensor.scatter -> tile.load(input/index/src) + flat-index build + tile.scatter.""" @@ -2462,6 +2455,11 @@ def main( out: pl.Tensor[[16, 8], pl.FP32] = self.main_incore_0(inp, idx, src) return out + # Runs under the autouse roundtrip instrument. The preserve blend emits + # tile.cmps whose packed-mask result has valid_shape [N, 1] on a wider + # physical tile; with #1498 fixed (parser now infers the implicit blayout + # from the physical tile shape, not valid_shape) the print->parse roundtrip + # holds, so this conversion no longer needs to be skipped. After = passes.convert_tensor_to_tile_ops()(Before) after_src = After.as_python()