Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ jobs:
timeout 60 git clone https://github.com/hw-native-sys/pto-isa.git $GITHUB_WORKSPACE/pto-isa \
|| { rm -rf $GITHUB_WORKSPACE/pto-isa; timeout 300 git clone https://gitcode.com/luohuan40/pto-isa.git $GITHUB_WORKSPACE/pto-isa; }
cd $GITHUB_WORKSPACE/pto-isa
git checkout 2c607938
git checkout 8bd3ac8f30bd237f9eaf12c142002a5cc0edb143

- name: Install simpler
run: |
Expand All @@ -219,7 +219,7 @@ jobs:
# runs here despite the tests/st/codegen ignore: its numeric golden
# compare is unstable on the CPU container's torch/BLAS build (see the
# codegen-tests job), so it needs this environment.
run: pytest tests/st tests/st/codegen/torch/test_torch_codegen_paged_attention.py -v --device="$DEVICE_RANGE" --precompile-workers=128 --pto-isa-commit=2c607938 --ignore=tests/st/distributed --ignore=tests/st/codegen
run: pytest tests/st tests/st/codegen/torch/test_torch_codegen_paged_attention.py -v --device="$DEVICE_RANGE" --precompile-workers=128 --pto-isa-commit=8bd3ac8f30bd237f9eaf12c142002a5cc0edb143 --ignore=tests/st/distributed --ignore=tests/st/codegen

- name: Test multi-orch L2 (isolated pytest invocation)
# A ``Worker(level=2)`` device session currently leaves runtime
Expand All @@ -228,12 +228,12 @@ jobs:
# isolation bug). Running this file in its own pytest process
# keeps the L2 leak from poisoning ``test_l3_distributed``.
timeout-minutes: 3
run: pytest tests/st/distributed/test_l2_multi_orch.py -v --device="$DEVICE_RANGE" --pto-isa-commit=2c607938
run: pytest tests/st/distributed/test_l2_multi_orch.py -v --device="$DEVICE_RANGE" --pto-isa-commit=8bd3ac8f30bd237f9eaf12c142002a5cc0edb143

- name: Test swimlane output
run: |
pytest tests/st/runtime/framework_and_models/test_perf_swimlane.py \
-v --device="$DEVICE_ID" --platform=a2a3 --enable-l2-swimlane --forked --pto-isa-commit=2c607938
-v --device="$DEVICE_ID" --platform=a2a3 --enable-l2-swimlane --forked --pto-isa-commit=8bd3ac8f30bd237f9eaf12c142002a5cc0edb143

dist-system-tests:
# No container: HCCL needs host-only env (HCCL_LOGIC_SUPERPOD_ID, plugin
Expand Down Expand Up @@ -409,7 +409,7 @@ jobs:
timeout 60 git clone https://github.com/hw-native-sys/pto-isa.git $GITHUB_WORKSPACE/pto-isa \
|| { rm -rf $GITHUB_WORKSPACE/pto-isa; timeout 300 git clone https://gitcode.com/luohuan40/pto-isa.git $GITHUB_WORKSPACE/pto-isa; }
cd $GITHUB_WORKSPACE/pto-isa
git checkout 2c607938
git checkout 8bd3ac8f30bd237f9eaf12c142002a5cc0edb143

- name: Install simpler
run: |
Expand Down Expand Up @@ -490,10 +490,10 @@ jobs:
# against the runtime's bundled pto-isa, which still clones from the
# stale PTO-ISA/pto-isa mirror (lacks pto::Coalesce) until the runtime
# submodule is bumped past simpler #806. Re-enable once that lands.
run: pytest tests/st/runtime/ops/test_assemble.py tests/st/runtime/ops/test_mscatter.py tests/st/runtime/framework_and_models/test_qwen3_decode_scope3_mixed.py tests/st/runtime/control_flow/test_dyn_orch_shape.py::TestDynOrchShapeOperations::test_dyn_orch_paged_attention -v --platform=a5sim --forked --pto-isa-commit=2c607938 -k "not TestMscatter"
run: pytest tests/st/runtime/ops/test_assemble.py tests/st/runtime/ops/test_mscatter.py tests/st/runtime/framework_and_models/test_qwen3_decode_scope3_mixed.py tests/st/runtime/control_flow/test_dyn_orch_shape.py::TestDynOrchShapeOperations::test_dyn_orch_paged_attention -v --platform=a5sim --forked --pto-isa-commit=8bd3ac8f30bd237f9eaf12c142002a5cc0edb143 -k "not TestMscatter"

- name: Test A5 cross-core system tests (simulator)
run: pytest tests/st/runtime/cross_core/test_cross_core.py -v --forked --platform=a5sim --pto-isa-commit=2c607938
run: pytest tests/st/runtime/cross_core/test_cross_core.py -v --forked --platform=a5sim --pto-isa-commit=8bd3ac8f30bd237f9eaf12c142002a5cc0edb143

- name: Test A2A3 cross-core system tests (simulator)
run: pytest tests/st/runtime/cross_core/test_cross_core.py -v --forked --platform=a2a3sim --pto-isa-commit=2c607938
run: pytest tests/st/runtime/cross_core/test_cross_core.py -v --forked --platform=a2a3sim --pto-isa-commit=8bd3ac8f30bd237f9eaf12c142002a5cc0edb143
12 changes: 11 additions & 1 deletion python/pypto/language/parser/type_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1309,8 +1309,18 @@ def _resolve_tileview( # noqa: PLR0912
# fields matching the memory-space-aware implicit defaults; the parser must
# recover them from the same rules — regardless of whether memory_space is
# present, since shape-derived defaults (e.g. col_major for [N,1]) also apply.
#
# The basis MUST be the full tile shape, not valid_shape: the printer elides
# against ``GetImplicitTileView(tile_type.shape_, ...)`` (the physical tile
# shape), so the parser has to infer from the same shape. Using valid_shape
# here desynchronized the two for packed-mask tiles — e.g. a cmp/cmps result
# with physical shape [16, 8] but valid_shape [16, 1]: the printer omits the
# (row_major) blayout, while valid_shape's cols==1 made the parser fill
# col_major, so the print->parse roundtrip failed with "TileView blayout
# mismatch" (#1498). Prefer tile_shape, falling back to valid_shape only when
# the physical shape is unavailable.
impl_blayout, impl_slayout, impl_fractal = _implicit_tile_view_defaults(
valid_shape if valid_shape else (tile_shape or []), memory_space
tile_shape if tile_shape else (valid_shape or []), memory_space
)
return ir.TileView(
valid_shape=valid_shape if valid_shape is not None else [],
Expand Down
31 changes: 21 additions & 10 deletions src/backend/common/pto_ops_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,14 @@ static std::string MakeScatterCodegenPTO(const CallPtr& op, codegen::CodegenBase
std::ostringstream oss;
oss << "pto.tscatter ins(" << src << ", " << idx;
// Emit the type clause only when both annotations are present; printing one
// alone would produce malformed PTOAS (": , idx" or ": src, ").
// alone would produce malformed PTOAS (": , idx" or ": src, "). The two
// operands are typed tiles produced by the same lowering, so they should
// either both carry an annotation or (in untyped contexts) both lack one — a
// one-sided annotation signals a real codegen bug, not a valid input.
INTERNAL_CHECK_SPAN(src_type.empty() == idx_type.empty(), op->span_)
<< "Internal error: tile.scatter src/indexes type annotations must both be present or both "
"absent, got src_type='"
<< src_type << "', idx_type='" << idx_type << "'";
if (!src_type.empty() && !idx_type.empty()) {
oss << " : " << src_type << ", " << idx_type;
}
Expand All @@ -988,16 +995,17 @@ static std::string MakeScatterCodegenPTO(const CallPtr& op, codegen::CodegenBase
}

// Helper for tile.scatter_mask (TSCATTER mask form, DPS):
// pto.tscatter ins(%src : src_ty) outs(%dst : dst_ty)
// {maskPattern = #pto.mask_pattern<Pxxxx>}
// pto.tscatter ins(%src, {maskPattern = #pto.mask_pattern<Pxxxx>} : src_ty)
// outs(%dst : dst_ty)
//
// Unlike pto.tgather's mask form (which carries maskPattern inside ins()),
// pto.tscatter only accepts SSA operands in ins() — the maskPattern must be a
// trailing op attribute dict (same placement as gather_compare's cmpMode).
// The maskPattern rides *inside* ins() right after the src operand, exactly
// like pto.tgather's mask form — PTOAS parses ins() as "src, attr-dict :
// type" and rejects a bare ins(%src ...) ("expected ',' after src operand").
// The type annotation follows the attr dict, still inside ins().
//
// IR surface: 2-input op (dst, src) + mask_pattern attr; dst aliased via
// set_output_reuses_input(0). Mask form is targeted at A3 / CPU-sim style
// backends; A5 rejects it on the PTOAS side.
// set_output_reuses_input(0). Mask form is targeted at A2/A3 backends; A5
// (Ascend950) rejects it on the PTOAS side.
static std::string MakeScatterMaskCodegenPTO(const CallPtr& op, codegen::CodegenBase& codegen_base) {
auto& codegen = dynamic_cast<codegen::PTOCodegen&>(codegen_base);
CHECK(op->args_.size() == 2) << "tile.scatter_mask requires 2 arguments (dst, src), but got "
Expand All @@ -1021,15 +1029,18 @@ static std::string MakeScatterMaskCodegenPTO(const CallPtr& op, codegen::Codegen
<< ", input=" << input_ssa;

std::ostringstream oss;
oss << "pto.tscatter ins(" << src;
// maskPattern rides inside ins() after src, then the type annotation:
// pto.tscatter ins(%src, {maskPattern = #pto.mask_pattern<Pxxxx>} : src_ty) outs(%dst : dst_ty)
oss << "pto.tscatter ins(" << src << ", {maskPattern = #pto.mask_pattern<" << mask_patterns.at(pattern)
<< ">}";
if (!src_type.empty()) {
oss << " : " << src_type;
}
oss << ") outs(" << dst;
if (!dst_type.empty()) {
oss << " : " << dst_type;
}
oss << ") {maskPattern = #pto.mask_pattern<" << mask_patterns.at(pattern) << ">}";
oss << ")";

codegen.Emit(oss.str());
return "";
Expand Down
7 changes: 7 additions & 0 deletions src/ir/op/tile_ops/scatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@
*
* Both ops are DPS — `dst` is the in/out buffer; the IR result aliases `dst`
* via `set_output_reuses_input(...)`. There is no compare form for scatter.
*
* Duplicate-index ordering: when two index entries map to the same destination
* slot, pto.tscatter resolves the collision in ascending element order (the
* later/higher-index write wins), matching torch `scatter_`'s last-wins
* semantics along the scan axis. Callers that build flat indices
* (ConvertTensorToTileOps) and the ST reference both rely on this order; it is
* a pto.tscatter ABI guarantee, not a PyPTO-side choice.
*/

#include <any>
Expand Down
106 changes: 101 additions & 5 deletions src/ir/transforms/op_conversion_registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1432,7 +1432,10 @@ void OpConversionRegistry::RegisterGatherOps() {
// select blend. The surrounding pass wraps the tile result in a tile.store to
// the output tensor param.
//
// tensor.scatter_mask: same idea, simple (input, dst) → (dst, src) re-wire.
// tensor.scatter_mask: same idea — pto.tscatter (mask form) zero-fills the whole
// dst before writing the selected columns, so it does not preserve dst either.
// We reconstruct DPS preserve with the same zeroed-scatter + mask + select
// blend, which also makes chaining two patterns into one dst sound.
// ============================================================================

void OpConversionRegistry::RegisterScatterOps() {
Expand Down Expand Up @@ -1476,6 +1479,10 @@ void OpConversionRegistry::RegisterScatterOps() {
// flat_idx = index + row_base (row-broadcast add)
auto& op_reg = OpRegistry::GetInstance();
auto src_rows = As<ConstInt>(src_tile->shape_[0]);
// `cols` is the flat-layout column count of the scattered destination.
// input/output and the inner tile.scatter dst all share this width (S),
// so reading it off `input` is equivalent to the dst column count used
// in the `i * dst_cols` flat-index formula above.
auto dst_cols_c = As<ConstInt>(input_tile->shape_[1]);
CHECK(src_rows && dst_cols_c)
<< "tensor.scatter conversion requires static src rows and dst cols for index expansion";
Expand All @@ -1486,6 +1493,26 @@ void OpConversionRegistry::RegisterScatterOps() {
// for 4-byte dst, INT16 for 2/1-byte dst).
const DataType idx_dtype = idx_tile->dtype_;

// INT16 flat-index range guard. For 2-byte element dtypes the flattened
// destination indices are INT16, whose largest representable value is
// 32767. The biggest index this lowering produces is
// (n-1)*cols + (cols-1) == n*cols - 1, so n*cols must stay <= 32768.
// 4-byte dtypes use INT32 indices and are effectively unbounded here.
// Without this check an oversized tile would silently overflow INT16 and
// scatter to wrong addresses instead of failing loudly.
if (idx_dtype == DataType::INT16) {
// Bound via division so the product never overflows int64_t: rows is
// capped so rows*cols stays <= 32768. cols is always > 0 here (a
// 2-byte tile has at least one column), but guard against 0 anyway.
const int64_t kMaxFlat = 32768;
const int64_t max_rows = cols == 0 ? kMaxFlat : kMaxFlat / cols;
CHECK(n <= max_rows) << "tensor.scatter with element dtype " << input_tile->dtype_.ToString()
<< " uses INT16 flattened indices, but the destination is too large: rows("
<< n << ") * cols(" << cols
<< ") exceeds the INT16 index range (max flat index 32767, rows <= "
<< max_rows << "). Use a smaller tile or split the scatter into chunks.";
}

auto make_idx = [&](int64_t v) -> ExprPtr {
return std::make_shared<ConstInt>(v, DataType::INDEX, span);
};
Expand Down Expand Up @@ -1611,10 +1638,79 @@ void OpConversionRegistry::RegisterScatterOps() {
CHECK(args.size() == 2) << "tensor.scatter_mask conversion expects 2 args (input, dst), got "
<< args.size();
auto& op_reg = OpRegistry::GetInstance();
// tile.scatter_mask signature: (dst, src) + mask_pattern attr. The
// converter's args[1] (dst) maps to dst, args[0] (input) to src.
auto tile_call = op_reg.Create("tile.scatter_mask", {args[1], args[0]}, kwargs, span);
return ConversionResult{tile_call};
auto input_tile = As<TileType>(args[0]->GetType());
auto dst_tile = As<TileType>(args[1]->GetType());
CHECK(input_tile && dst_tile)
<< "tensor.scatter_mask conversion: input/dst must be Vec tiles after bridge";

// pto.tscatter (mask form) zero-fills the entire dst tile before writing
// the mask-selected columns (TScatterMaskImpl calls InitUBBuffer), so it
// does NOT preserve dst's unselected columns — they read back as 0. To
// honour the DPS preserve contract (and make chaining two patterns into
// one dst sound), reconstruct preserve on the PyPTO side with the same
// zeroed-scatter + mask + select blend as the index form:
//
// scattered = scatter_mask(zeros, input) # input @selected, 0 @unselected
// mask = scatter_mask(zeros_m, ones) # 1 @selected, 0 @unselected
// pred = (mask != 0) # packed predicate
// out = sel(pred, scattered, dst) # selected→scattered, else→dst
auto in_rows = As<ConstInt>(input_tile->shape_[0]);
auto in_cols = As<ConstInt>(input_tile->shape_[1]);
auto dst_cols_c = As<ConstInt>(dst_tile->shape_[1]);
CHECK(in_rows && in_cols && dst_cols_c)
<< "tensor.scatter_mask conversion requires static shapes for the preserve blend";
const int64_t b = in_rows->value_;
const int64_t c = in_cols->value_;
const int64_t dst_cols = dst_cols_c->value_;
const DataType dt = dst_tile->dtype_;

// Mask blend dtype: a compare-friendly type within dst's element size
// (bf16 → f16) so tile.cmps is well-defined for any supported dtype.
const int dt_bytes = static_cast<int>(dt.GetBit()) / 8;
const DataType mask_dt = (dt_bytes == 4) ? DataType(DataType::FP32)
: (dt_bytes == 2) ? DataType(DataType::FP16)
: DataType(DataType::INT8);

auto make_idx = [&](int64_t v) -> ExprPtr {
return std::make_shared<ConstInt>(v, DataType::INDEX, span);
};
std::vector<StmtPtr> prologue;
auto emit = [&](const std::string& op_name, const std::vector<ExprPtr>& op_args,
const std::vector<std::pair<std::string, std::any>>& op_kwargs,
const std::string& name) -> VarPtr {
auto call = op_kwargs.empty() ? op_reg.Create(op_name, op_args, span)
: op_reg.Create(op_name, op_args, op_kwargs, span);
auto var = std::make_shared<Var>(name, call->GetType(), span);
prologue.push_back(std::make_shared<AssignStmt>(var, call, span));
return var;
};
auto make_full = [&](const DataType& fdt, int64_t rows, int64_t cols_, double v,
const std::string& name) -> VarPtr {
ExprPtr val = fdt.IsFloat()
? ExprPtr(std::make_shared<ConstFloat>(v, fdt, span))
: ExprPtr(std::make_shared<ConstInt>(static_cast<int64_t>(v), fdt, span));
return emit("tile.full", {MakeShapeTuple({make_idx(rows), make_idx(cols_)}, span), val},
{{"dtype", fdt}}, name);
};

// scattered = input written into the mask-selected columns of a zeroed dst.
auto values_zero = make_full(dt, b, dst_cols, 0.0, "scatter_mask_values_zero");
auto scattered = emit("tile.scatter_mask", {values_zero, args[0]}, kwargs, "scatter_mask_values");
// mask = ones written into the same selected columns of a zeroed base.
auto mask_zero = make_full(mask_dt, b, dst_cols, 0.0, "scatter_mask_mask_zero");
auto ones_src = make_full(mask_dt, b, c, 1.0, "scatter_mask_ones");
auto mask = emit("tile.scatter_mask", {mask_zero, ones_src}, kwargs, "scatter_mask_mask");
// pred = (mask != 0) → packed predicate (NE = cmp_type 1).
ExprPtr zero_scalar = mask_dt.IsFloat() ? ExprPtr(std::make_shared<ConstFloat>(0.0, mask_dt, span))
: ExprPtr(std::make_shared<ConstInt>(0, mask_dt, span));
auto pred = emit("tile.cmps", {mask, zero_scalar}, {{"cmp_type", 1}}, "scatter_mask_pred");
// tmp = TSEL scratch tile (UINT8 [1, 32]).
auto tmp = emit("tile.create", {MakeShapeTuple({make_idx(1), make_idx(32)}, span)},
{{"dtype", DataType(DataType::UINT8)}, {"target_memory", MemorySpace::Vec}},
"scatter_mask_sel_tmp");
// out = sel(pred, scattered, dst, tmp): scattered @selected, dst @unselected.
auto out_call = op_reg.Create("tile.sel", {pred, scattered, args[1], tmp}, span);
return ConversionResult{std::move(prologue), out_call};
},
std::move(scatter_mask_input_reqs));
}
Expand Down
Loading
Loading