diff --git a/docs/PTO_IR_manual.md b/docs/PTO_IR_manual.md index a32de72f7..b2510b234 100644 --- a/docs/PTO_IR_manual.md +++ b/docs/PTO_IR_manual.md @@ -150,9 +150,10 @@ A logical partition (slice) of a `tensor_view`. Holds shape and stride informati | `v_row` | `int64` or `?` | Valid row count | | `v_col` | `int64` or `?` | Valid column count | | `blayout` | `BLayout` mnemonic | Base layout (`row_major` / `col_major`) | -| `slayout` | `SLayout` mnemonic | Secondary layout (`none_box` / `row_major` / `col_major`) | +| `slayout` | `SLayout` mnemonic | Secondary layout (`none_box` / `row_major` / `col_major` / `nc1hwc0` / `ndc1hwc0`) | | `fractal` | `int32` | Fractal size | | `pad` | `PadValue` mnemonic or integer literal | Padding policy/value selector (tests commonly use `pad=0`) | +| `fmap_h/pad_l/.../repeat_mode/transpose` | `int64` literals (optional) | Extra img2col metadata used when `slayout=nc1hwc0` or `ndc1hwc0` | Here, `?` denotes a dynamic symbol resolved at runtime. @@ -160,6 +161,20 @@ For `dtype=!pto.f4E1M2x2` and `dtype=!pto.f4E2M1x2`, the `rows`/`cols` and `v_row`/`v_col` values are physical packed extents. In other words, the packed dimension counts FP4 pairs stored per byte, not logical scalar FP4 elements. +For A5 img2col authoring, PTO IR does not introduce a dedicated `ConvTile` +type. Instead, a MAT `tile_buf` with `slayout=nc1hwc0` or `slayout=ndc1hwc0` +plus img2col metadata fields represents the ConvTile surface: + +- `fmap_h`, `fmap_w` +- `pad_l`, `pad_r`, `pad_t`, `pad_b` +- `stride_h`, `stride_w` +- `dilation_h`, `dilation_w` +- `filter_h`, `filter_w` +- `channel_size` +- `dst_stride`, `dst_m_position` +- `repeat_stride`, `repeat_time`, `repeat_mode` +- `transpose` + **Syntax:** ```mlir !pto.tile_buf @@ -267,12 +282,16 @@ Composite attribute and component enums for tile buffer configuration. | Parameter | Type | Description | |-----------|------|-------------| | `bLayout` | `BLayoutAttr` | Base layout (RowMajor / ColMajor) | -| `sLayout` | `SLayoutAttr` | Secondary layout (NoneBox / RowMajor / ColMajor) | +| `sLayout` | `SLayoutAttr` | Secondary layout (NoneBox / RowMajor / ColMajor / NC1HWC0 / NDC1HWC0) | | `sFractalSize` | `IntegerAttr (i32)` | Secondary fractal size | | `pad` | `PadValueAttr` | Pad value policy | +| `img2colConfig` | `DictionaryAttr` (optional) | Lowered img2col metadata dictionary for ConvTile-like MAT tiles **Syntax:** `#pto.tile_buf_config` +For img2col-lowered MAT tiles, `img2colConfig` is carried as +`img2col_config={...}` inside `#pto.tile_buf_config<...>`. + **BLayout** (Base layout): | Value | Int | Mnemonic | @@ -287,6 +306,8 @@ Composite attribute and component enums for tile buffer configuration. | `NoneBox` | 0 | `none_box` | | `RowMajor` | 1 | `row_major` | | `ColMajor` | 2 | `col_major` | +| `NC1HWC0` | 3 | `nc1hwc0` | +| `NDC1HWC0` | 4 | `ndc1hwc0` | **PadValue** (Pad value policy): @@ -299,7 +320,24 @@ Composite attribute and component enums for tile buffer configuration. --- -### 3.5 Layout +### 3.5 SetFmatrixMode + +Selects whether img2col uses the A-register bank or B-register bank, and +whether the `TIMG2COL` wrapper auto-programs the auxiliary fmatrix/repeat/padding +registers. + +| Value | Int | Mnemonic | +|-------|-----|----------| +| `FMATRIX_A_AUTO` | 0 | `fmatrix_a_auto` | +| `FMATRIX_B_AUTO` | 1 | `fmatrix_b_auto` | +| `FMATRIX_A_MANUAL` | 2 | `fmatrix_a_manual` | +| `FMATRIX_B_MANUAL` | 3 | `fmatrix_b_manual` | + +**Attribute syntax:** `#pto.set_fmatrix_mode` + +--- + +### 3.6 Layout Global tensor layout inference for [`tensor_view` (Section 2.3)](#23-ptotensor_viewd0-x-d1-x-elementtype)/[`partition_tensor_view` (Section 2.4)](#24-ptopartition_tensor_viewd0-x-d1-x-elementtype). Tile buffers additionally use **Tile Buf config** (see 3.4) to describe physical/fractal layout. @@ -7390,6 +7428,116 @@ pto.tfillpad_inplace ins(%tile : !pto.tile_buf +``` + +--- + +##### `pto.set_img2col_padding` - Program Img2col Padding State + +**Summary:** Programs the img2col padding register from ConvTile-like MAT tile metadata. + +**Arguments:** Same as `pto.set_img2col_fmatrix`. + +**Results:** None. + +**Constraints & Verification:** + +- Same structural constraints as `pto.set_img2col_fmatrix`. +- Current PTOAS implementation accepts `pad=null` and `pad=zero`; `pad=zero` lowers to `set_padding(0)` / `set_padding_b(0)`. + +**Hardware Mapping:** + +- EmitC lowers to `set_padding` / `set_padding_b` via a helper wrapper. +- Executes on the **FIX pipeline** (`PIPE_FIX`). + +--- + +##### `pto.set_img2col_repeat` - Program Img2col Repeat State + +**Summary:** Programs the img2col repeat register from ConvTile-like MAT tile metadata. + +**Arguments:** Same as `pto.set_img2col_fmatrix`. + +**Results:** None. + +**Constraints & Verification:** Same structural constraints as `pto.set_img2col_fmatrix`. + +**Hardware Mapping:** + +- EmitC lowers to `set_l3d_rpt` / `set_l3d_rpt_b` via a helper wrapper. +- Executes on the **FIX pipeline** (`PIPE_FIX`). + +--- + +##### `pto.timg2col` - Convert ConvTile to Left Tile + +**Summary:** Performs the img2col transform from a ConvTile-like MAT tile into a LEFT tile suitable for cube/matmul consumption. + +**Arguments:** + +| Name | Type | Description | +|------|------|-------------| +| `src` | `pto.tile_buf` | Source MAT tile with `slayout=nc1hwc0` or `ndc1hwc0` and img2col metadata | +| `dst` | `pto.tile_buf` | Destination LEFT tile | +| `posM` | `i32` attribute | Optional M-position, default `0` | +| `posK` | `i32` attribute | Optional K-position, default `0` | +| `fmatrixMode` | `SetFmatrixModeAttr` | Optional mode controlling A/B bank selection and auto/manual helper behavior | + +**Results:** None. Writes into `dst` via DPS pattern. + +**Constraints & Verification:** + +- A5 only. +- `src.loc` must be `mat`; `dst.loc` must be `left`. +- `src.blayout` must be `row_major`. +- `src.slayout` must be `nc1hwc0` or `ndc1hwc0`. +- `dst.blayout` must be `col_major`; `dst.slayout` must be `row_major`. +- `src` and `dst` element types must match and must be one of `i8/i16/i32/f16/bf16/f32`. + +**Hardware Mapping:** + +- EmitC lowers to `TIMG2COL(dst, src, posM, posK)`. +- The actual hardware instruction is `img2colv2_cbuf_to_ca`. +- Executes on the **MTE1 pipeline** (`PIPE_MTE1`). + +**Basic Example:** + +```mlir +pto.timg2col ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` + +--- + ### 4.11 Sorting Operations ##### `pto.tsort32` - Sort Fixed 32-Element Blocks diff --git a/include/PTO/IR/PTOAttrs.td b/include/PTO/IR/PTOAttrs.td index fc1e0ca2f..ed5be87fa 100644 --- a/include/PTO/IR/PTOAttrs.td +++ b/include/PTO/IR/PTOAttrs.td @@ -598,7 +598,9 @@ def PTO_BLayoutAttr : PTO_Attr<"BLayout", "blayout"> { def PTO_SLayout_Enum : PTO_I32Enum<"SLayout", "Secondary layout", [ I32EnumAttrCase<"NoneBox", 0, "none_box">, I32EnumAttrCase<"RowMajor", 1, "row_major">, - I32EnumAttrCase<"ColMajor", 2, "col_major"> + I32EnumAttrCase<"ColMajor", 2, "col_major">, + I32EnumAttrCase<"NC1HWC0", 3, "nc1hwc0">, + I32EnumAttrCase<"NDC1HWC0", 4, "ndc1hwc0"> ]>; def PTO_SLayoutAttr : PTO_Attr<"SLayout", "slayout"> { @@ -629,6 +631,18 @@ def PTO_CompactModeAttr : PTO_Attr<"CompactMode", "compact_mode"> { let assemblyFormat = "`<` params `>`"; } +def PTO_SetFmatrixMode_Enum : PTO_I32Enum<"SetFmatrixMode", "Img2col fmatrix programming mode", [ + I32EnumAttrCase<"FMATRIX_A_AUTO", 0, "fmatrix_a_auto">, + I32EnumAttrCase<"FMATRIX_B_AUTO", 1, "fmatrix_b_auto">, + I32EnumAttrCase<"FMATRIX_A_MANUAL", 2, "fmatrix_a_manual">, + I32EnumAttrCase<"FMATRIX_B_MANUAL", 3, "fmatrix_b_manual"> + ]>; + +def PTO_SetFmatrixModeAttr : PTO_Attr<"SetFmatrixMode", "set_fmatrix_mode"> { + let parameters = (ins EnumParameter:$value); + let assemblyFormat = "`<` params `>`"; +} + // ---------- tile_buf_config (NO b_fractal / s_fractal) ---------- def TileBufConfigAttr : AttrDef { let mnemonic = "tile_buf_config"; @@ -637,7 +651,8 @@ def TileBufConfigAttr : AttrDef { "SLayoutAttr":$sLayout, "mlir::IntegerAttr":$sFractalSize, // i32 "PadValueAttr":$pad, - "CompactModeAttr":$compactMode + "CompactModeAttr":$compactMode, + "mlir::Attribute":$img2colConfig ); let hasCustomAssemblyFormat = 1; @@ -649,19 +664,33 @@ def TileBufConfigAttr : AttrDef { "mlir::IntegerAttr":$sFractalSize, "PadValueAttr":$pad, "CompactModeAttr":$compactMode + ), [{ + return Base::get($_ctxt, bLayout, sLayout, sFractalSize, pad, compactMode, + mlir::Attribute()); + }]>, + AttrBuilder<(ins + "BLayoutAttr":$bLayout, + "SLayoutAttr":$sLayout, + "mlir::IntegerAttr":$sFractalSize, + "PadValueAttr":$pad, + "CompactModeAttr":$compactMode, + "mlir::Attribute":$img2colConfig )> ]; let extraClassDeclaration = [{ static TileBufConfigAttr getDefault(MLIRContext *ctx); bool isDefault() const; + mlir::DictionaryAttr getImg2colConfigDict() const; + bool hasImg2colConfig() const; static LogicalResult verify(function_ref emitError, mlir::Attribute bLayout, mlir::Attribute sLayout, mlir::IntegerAttr sFractalSize, mlir::Attribute pad, - mlir::Attribute compactMode); + mlir::Attribute compactMode, + mlir::Attribute img2colConfig); }]; } diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 344b399a6..9fdbc7df9 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -4225,6 +4225,106 @@ def TFillPadOp : PTO_TOp<"tfillpad", [ }]; } +def SetImg2colFmatrixOp : PTO_Op<"set_img2col_fmatrix", [ + OpPipeInterface +]> { + let summary = "Program the img2col fmatrix state from tile metadata"; + + let arguments = (ins + PTODpsType:$config, + OptionalAttr:$fmatrixMode + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $config attr-dict `:` qualified(type($config)) + }]; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_FIX; } + }]; +} + +def SetImg2colPaddingOp : PTO_Op<"set_img2col_padding", [ + OpPipeInterface +]> { + let summary = "Program the img2col padding value from tile metadata"; + + let arguments = (ins + PTODpsType:$config, + OptionalAttr:$fmatrixMode + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $config attr-dict `:` qualified(type($config)) + }]; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_FIX; } + }]; +} + +def SetImg2colRepeatOp : PTO_Op<"set_img2col_repeat", [ + OpPipeInterface +]> { + let summary = "Program the img2col repeat state from tile metadata"; + + let arguments = (ins + PTODpsType:$config, + OptionalAttr:$fmatrixMode + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $config attr-dict `:` qualified(type($config)) + }]; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_FIX; } + }]; +} + +def TImg2colOp : PTO_TOp<"timg2col", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods +]> { + let summary = "Perform img2col from MAT source into LEFT destination"; + + let arguments = (ins + PTODpsType:$src, + PTODpsType:$dst, + DefaultValuedAttr:$posM, + DefaultValuedAttr:$posK, + OptionalAttr:$fmatrixMode + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + `ins` `(` $src `:` qualified(type($src)) `)` + `outs` `(` $dst `:` qualified(type($dst) ) `)` + attr-dict + }]; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_MTE1; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; +} + def TFillPadExpandOp : PTO_TOp<"tfillpad_expand", [ PTO_DpsInitOpInterface, OpPipeInterface, diff --git a/include/PTO/IR/PTOTypeDefs.td b/include/PTO/IR/PTOTypeDefs.td index 665014fc3..b73e00df8 100644 --- a/include/PTO/IR/PTOTypeDefs.td +++ b/include/PTO/IR/PTOTypeDefs.td @@ -212,10 +212,12 @@ def TileBufType : TypeDef { int32_t getSFractalSizeI32() const; mlir::Attribute getPadValueAttr() const; mlir::Attribute getCompactModeAttr() const; + mlir::DictionaryAttr getImg2colConfigAttr() const; + bool hasImg2colConfig() const; // 如果你仍然想要“数值枚举”,就提供 int getter(不会依赖 enum 类型) int32_t getBLayoutValueI32() const; // 0 row_major, 1 col_major - int32_t getSLayoutValueI32() const; // 0 none_box, 1 row_major, 2 col_major + int32_t getSLayoutValueI32() const; // 0 none_box, 1 row_major, 2 col_major, 3 nc1hwc0, 4 ndc1hwc0 int32_t getPadValueI32() const; // 0 null, 1 zero, 2 max, 3 min int32_t getCompactModeI32() const; // 0 null, 1 normal, 2 row_plus_one }]; diff --git a/include/pto-c/Dialect/PTO.h b/include/pto-c/Dialect/PTO.h index fac8f2f57..9b62168e9 100644 --- a/include/pto-c/Dialect/PTO.h +++ b/include/pto-c/Dialect/PTO.h @@ -102,6 +102,9 @@ MLIR_CAPI_EXPORTED int32_t mlirPTOPadValueAttrGetValue(MlirAttribute attr); MLIR_CAPI_EXPORTED bool mlirPTOAttrIsACompactModeAttr(MlirAttribute attr); MLIR_CAPI_EXPORTED MlirAttribute mlirPTOCompactModeAttrGet(MlirContext ctx, int32_t value); MLIR_CAPI_EXPORTED int32_t mlirPTOCompactModeAttrGetValue(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool mlirPTOAttrIsASetFmatrixModeAttr(MlirAttribute attr); +MLIR_CAPI_EXPORTED MlirAttribute mlirPTOSetFmatrixModeAttrGet(MlirContext ctx, int32_t value); +MLIR_CAPI_EXPORTED int32_t mlirPTOSetFmatrixModeAttrGetValue(MlirAttribute attr); MLIR_CAPI_EXPORTED bool mlirPTOAttrIsAAccToVecModeAttr(MlirAttribute attr); MLIR_CAPI_EXPORTED MlirAttribute mlirPTOAccToVecModeAttrGet(MlirContext ctx, int32_t value); MLIR_CAPI_EXPORTED int32_t mlirPTOAccToVecModeAttrGetValue(MlirAttribute attr); @@ -193,6 +196,11 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirPTOTileBufConfigAttrGetWithCompactMode( MlirAttribute bLayout, MlirAttribute sLayout, MlirAttribute sFractalSize, MlirAttribute pad, MlirAttribute compactMode); +MLIR_CAPI_EXPORTED MlirAttribute mlirPTOTileBufConfigAttrGetWithImg2colConfig( + MlirContext ctx, + MlirAttribute bLayout, MlirAttribute sLayout, + MlirAttribute sFractalSize, MlirAttribute pad, + MlirAttribute compactMode, MlirAttribute img2colConfig); MLIR_CAPI_EXPORTED MlirType mlirPTOTileBufTypeGetWithValidShape( MlirContext ctx, intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute memorySpace, intptr_t validRank, const int64_t *validShape); diff --git a/lib/Bindings/Python/PTOModule.cpp b/lib/Bindings/Python/PTOModule.cpp index e9b7e2535..12026c6f1 100644 --- a/lib/Bindings/Python/PTOModule.cpp +++ b/lib/Bindings/Python/PTOModule.cpp @@ -103,7 +103,9 @@ static void bindPTOModule(pybind11::module &m) { py::enum_(m, "SLayout") .value("NoneBox", mlir::pto::SLayout::NoneBox) .value("RowMajor", mlir::pto::SLayout::RowMajor) - .value("ColMajor", mlir::pto::SLayout::ColMajor); + .value("ColMajor", mlir::pto::SLayout::ColMajor) + .value("NC1HWC0", mlir::pto::SLayout::NC1HWC0) + .value("NDC1HWC0", mlir::pto::SLayout::NDC1HWC0); py::enum_(m, "PadValue") .value("Null", mlir::pto::PadValue::Null) @@ -111,6 +113,12 @@ static void bindPTOModule(pybind11::module &m) { .value("Max", mlir::pto::PadValue::Max) .value("Min", mlir::pto::PadValue::Min); + py::enum_(m, "SetFmatrixMode") + .value("FMATRIX_A_AUTO", mlir::pto::SetFmatrixMode::FMATRIX_A_AUTO) + .value("FMATRIX_B_AUTO", mlir::pto::SetFmatrixMode::FMATRIX_B_AUTO) + .value("FMATRIX_A_MANUAL", mlir::pto::SetFmatrixMode::FMATRIX_A_MANUAL) + .value("FMATRIX_B_MANUAL", mlir::pto::SetFmatrixMode::FMATRIX_B_MANUAL); + py::enum_(m, "CompactMode") .value("Null", mlir::pto::CompactMode::Null) .value("Normal", mlir::pto::CompactMode::Normal) @@ -286,6 +294,33 @@ static void bindPTOModule(pybind11::module &m) { }, py::arg("cls"), py::arg("value"), py::arg("context") = py::none()); + mlir_attribute_subclass(m, "SetFmatrixModeAttr", + [](MlirAttribute a) -> bool { + return mlirPTOAttrIsASetFmatrixModeAttr(a); + }) + .def_classmethod( + "get", + [](py::object cls, py::object value, MlirContext ctx) -> py::object { + int32_t v = 0; + if (py::isinstance(value)) { + v = value.cast(); + } else if (py::hasattr(value, "value")) { + v = value.attr("value").cast(); + } else { + throw std::runtime_error( + "SetFmatrixModeAttr.get expects int or SetFmatrixMode enum"); + } + MlirAttribute a = mlirPTOSetFmatrixModeAttrGet(ctx, v); + if (mlirAttributeIsNull(a)) return py::none(); + return cls.attr("__call__")(a); + }, + py::arg("cls"), py::arg("value"), py::arg("context") = py::none()) + .def_property_readonly( + "value", + [](MlirAttribute self) -> int32_t { + return mlirPTOSetFmatrixModeAttrGetValue(self); + }); + mlir_attribute_subclass(m, "AccToVecModeAttr", [](MlirAttribute a) -> bool { return mlirPTOAttrIsAAccToVecModeAttr(a); @@ -846,11 +881,13 @@ static void bindPTOModule(pybind11::module &m) { int32_t s_fractal_size, MlirAttribute pad, MlirContext ctx, - py::object compactModeObj) -> py::object { + py::object compactModeObj, + py::object img2colConfigObj) -> py::object { MlirType i32 = mlirIntegerTypeGet(ctx, 32); MlirAttribute sz = mlirIntegerAttrGet(i32, s_fractal_size); MlirAttribute compactMode = mlirPTOCompactModeAttrGet( ctx, static_cast(mlir::pto::CompactMode::Null)); + MlirAttribute img2colConfig = MlirAttribute{nullptr}; if (!compactModeObj.is_none()) { if (py::isinstance(compactModeObj)) { compactMode = mlirPTOCompactModeAttrGet( @@ -862,8 +899,12 @@ static void bindPTOModule(pybind11::module &m) { compactMode = compactModeObj.cast(); } } - MlirAttribute a = mlirPTOTileBufConfigAttrGetWithCompactMode( - ctx, blayout, slayout, sz, pad, compactMode); + if (!img2colConfigObj.is_none()) + img2colConfig = img2colConfigObj.cast(); + MlirAttribute a = + mlirPTOTileBufConfigAttrGetWithImg2colConfig( + ctx, blayout, slayout, sz, pad, compactMode, + img2colConfig); if (mlirAttributeIsNull(a)) return py::none(); return cls(a); }, @@ -873,7 +914,8 @@ static void bindPTOModule(pybind11::module &m) { py::arg("s_fractal_size"), py::arg("pad"), py::arg("context") = py::none(), - py::arg("compact_mode") = py::none()); + py::arg("compact_mode") = py::none(), + py::arg("img2col_config") = py::none()); // ---- TileBufType ---- mlir_type_subclass(m, "TileBufType", diff --git a/lib/CAPI/Dialect/PTO.cpp b/lib/CAPI/Dialect/PTO.cpp index 2ca0fef71..b88ff177d 100644 --- a/lib/CAPI/Dialect/PTO.cpp +++ b/lib/CAPI/Dialect/PTO.cpp @@ -618,6 +618,21 @@ int32_t mlirPTOCompactModeAttrGetValue(MlirAttribute attr) { return static_cast(a.getValue()); } +bool mlirPTOAttrIsASetFmatrixModeAttr(MlirAttribute attr) { + return mlir::isa(unwrap(attr)); +} + +MlirAttribute mlirPTOSetFmatrixModeAttrGet(MlirContext ctx, int32_t value) { + auto *c = unwrap(ctx); + return wrap(mlir::pto::SetFmatrixModeAttr::get( + c, static_cast(value))); +} + +int32_t mlirPTOSetFmatrixModeAttrGetValue(MlirAttribute attr) { + auto a = mlir::cast(unwrap(attr)); + return static_cast(a.getValue()); +} + bool mlirPTOAttrIsAAccToVecModeAttr(MlirAttribute attr) { return mlir::isa(unwrap(attr)); } @@ -723,6 +738,15 @@ MlirAttribute mlirPTOTileBufConfigAttrGet(MlirContext ctx, MlirAttribute mlirPTOTileBufConfigAttrGetWithCompactMode( MlirContext ctx, MlirAttribute bLayout, MlirAttribute sLayout, MlirAttribute sFractalSize, MlirAttribute pad, MlirAttribute compactMode) { + return mlirPTOTileBufConfigAttrGetWithImg2colConfig( + ctx, bLayout, sLayout, sFractalSize, pad, compactMode, + MlirAttribute{nullptr}); +} + +MlirAttribute mlirPTOTileBufConfigAttrGetWithImg2colConfig( + MlirContext ctx, MlirAttribute bLayout, MlirAttribute sLayout, + MlirAttribute sFractalSize, MlirAttribute pad, MlirAttribute compactMode, + MlirAttribute img2colConfig) { auto *c = unwrap(ctx); auto blA = toBLayoutAttr(c, unwrap(bLayout)); auto slA = toSLayoutAttr(c, unwrap(sLayout)); @@ -735,7 +759,8 @@ MlirAttribute mlirPTOTileBufConfigAttrGetWithCompactMode( if (!sz || !sz.getType().isInteger(kI32BitWidth)) return MlirAttribute{nullptr}; - return wrap(mlir::pto::TileBufConfigAttr::get(c, blA, slA, sz, pvA, cmA)); + return wrap(mlir::pto::TileBufConfigAttr::get( + c, blA, slA, sz, pvA, cmA, unwrap(img2colConfig))); } MlirType mlirPTOGMTypeGet(MlirContext ctx, intptr_t rank, const int64_t *shape, diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 5427ac36e..d9049be33 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -1469,6 +1469,20 @@ static LogicalResult verifyTileBufLayoutConstraints(Operation *op, << " to have concrete tile layout attributes"; constexpr int64_t kAlignedBytes = 32; + if (slayout == static_cast(SLayout::NC1HWC0) || + slayout == static_cast(SLayout::NDC1HWC0)) { + auto loc = getPTOMemorySpaceEnum(tb); + if (!loc || *loc != pto::AddressSpace::MAT) + return op->emitOpError() + << "expects " << name + << " nc1hwc0/ndc1hwc0 tiles to use MAT memory space"; + if (!tb.hasImg2colConfig()) + return op->emitOpError() + << "expects " << name + << " nc1hwc0/ndc1hwc0 tiles to carry img2col_config metadata"; + return success(); + } + auto checkByteAlignment = [&](int64_t dim, StringRef layoutName, StringRef byteExpr) -> LogicalResult { if (dim == ShapedType::kDynamic) @@ -1540,6 +1554,84 @@ static LogicalResult verifyTileBufLayoutConstraints(Operation *op, return success(); } +static std::optional getImg2colIntField(DictionaryAttr dict, + StringRef key) { + if (!dict) + return std::nullopt; + auto attr = dyn_cast_or_null(dict.get(key)); + if (!attr) + return std::nullopt; + return attr.getInt(); +} + +static LogicalResult verifyImg2colConfigLike(Operation *op, Type ty, + StringRef name) { + auto tb = dyn_cast(ty); + if (!tb) + return op->emitError("expects img2col config operand to be a tile_buf"); + auto dict = tb.getImg2colConfigAttr(); + if (!dict) + return op->emitError() << "expects " << name + << " to carry img2col_config metadata"; + if (tb.getSLayoutValueI32() != static_cast(pto::SLayout::NC1HWC0) && + tb.getSLayoutValueI32() != static_cast(pto::SLayout::NDC1HWC0)) + return op->emitError() << "expects " << name + << " slayout to be nc1hwc0 or ndc1hwc0"; + if (tb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) + return op->emitError() << "expects " << name + << " b_layout to be row_major"; + + auto requirePositive = [&](StringRef key) -> LogicalResult { + auto value = getImg2colIntField(dict, key); + if (!value.has_value() || *value <= 0) + return op->emitError() << "expects " << name << " img2col_config." << key + << " to be a positive integer"; + return success(); + }; + auto requireNonNegative = [&](StringRef key) -> LogicalResult { + auto value = getImg2colIntField(dict, key); + if (!value.has_value() || *value < 0) + return op->emitError() << "expects " << name << " img2col_config." << key + << " to be a non-negative integer"; + return success(); + }; + auto requireBoolLike = [&](StringRef key) -> LogicalResult { + auto value = getImg2colIntField(dict, key); + if (!value.has_value() || (*value != 0 && *value != 1)) + return op->emitError() << "expects " << name << " img2col_config." << key + << " to be 0 or 1"; + return success(); + }; + + if (failed(requirePositive("fmap_h")) || failed(requirePositive("fmap_w")) || + failed(requirePositive("stride_h")) || + failed(requirePositive("stride_w")) || + failed(requirePositive("dilation_h")) || + failed(requirePositive("dilation_w")) || + failed(requirePositive("filter_h")) || + failed(requirePositive("filter_w")) || + failed(requirePositive("channel_size")) || + failed(requireNonNegative("dst_stride")) || + failed(requireNonNegative("dst_m_position")) || + failed(requireNonNegative("repeat_stride")) || + failed(requireNonNegative("repeat_time")) || + failed(requireNonNegative("repeat_mode")) || + failed(requireBoolLike("transpose"))) + return failure(); + + return success(); +} + +static bool isSupportedTImg2colElemType(Type ty) { + if (ty.isF16() || ty.isBF16() || ty.isF32()) + return true; + auto intTy = dyn_cast(ty); + if (!intTy) + return false; + unsigned width = intTy.getWidth(); + return width == 8 || width == 16 || width == 32; +} + [[maybe_unused]] static bool isSupportedLoadStoreElemTypeA2A3(Type ty) { if (ty.isF16() || ty.isBF16() || ty.isF32()) return true; @@ -5937,6 +6029,80 @@ mlir::LogicalResult mlir::pto::TFillPadInplaceOp::verify() { /*allowDstExpand=*/false, "tfillpad_inplace"); } +static LogicalResult verifyTImg2colConfigOp(Operation *op, Type configTy, + StringRef name) { + if (getVerifierTargetArch(op) != VerifierTargetArch::A5) + return op->emitError("expects A5 img2col configuration ops"); + if (failed(verifyImg2colConfigLike(op, configTy, name))) + return failure(); + + auto loc = getPTOMemorySpaceEnum(configTy); + if (!loc || *loc != pto::AddressSpace::MAT) + return op->emitError() << "expects " << name + << " to use loc=mat for img2col configuration"; + + auto tb = cast(configTy); + if (tb.getPadValueI32() != static_cast(pto::PadValue::Null) && + tb.getPadValueI32() != static_cast(pto::PadValue::Zero)) + return op->emitError() + << "expects img2col padding mode to be null or zero for now"; + + return success(); +} + +LogicalResult mlir::pto::SetImg2colFmatrixOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + return verifyTImg2colConfigOp(getOperation(), getConfig().getType(), "config"); +} + +LogicalResult mlir::pto::SetImg2colPaddingOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + return verifyTImg2colConfigOp(getOperation(), getConfig().getType(), "config"); +} + +LogicalResult mlir::pto::SetImg2colRepeatOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + return verifyTImg2colConfigOp(getOperation(), getConfig().getType(), "config"); +} + +LogicalResult mlir::pto::TImg2colOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (getVerifierTargetArch(getOperation()) != VerifierTargetArch::A5) + return emitOpError("expects A5 timg2col"); + + auto srcTb = dyn_cast(getSrc().getType()); + auto dstTb = dyn_cast(getDst().getType()); + if (!srcTb || !dstTb) + return emitOpError("expects src/dst to be tile_buf types"); + + if (failed(verifyImg2colConfigLike(getOperation(), srcTb, "src"))) + return failure(); + + auto srcLoc = getPTOMemorySpaceEnum(srcTb); + auto dstLoc = getPTOMemorySpaceEnum(dstTb); + if (!srcLoc || *srcLoc != pto::AddressSpace::MAT) + return emitOpError("expects src to use loc=mat"); + if (!dstLoc || *dstLoc != pto::AddressSpace::LEFT) + return emitOpError("expects dst to use loc=left"); + if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) + return emitOpError("expects dst b_layout to be col_major"); + + if (srcTb.getElementType() != dstTb.getElementType()) + return emitOpError("expects src/dst element types to match"); + if (!isSupportedTImg2colElemType(srcTb.getElementType())) + return emitOpError( + "expects src/dst element types to be i8/i16/i32/f16/bf16/f32"); + + if (dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) + return emitOpError("expects dst slayout to be row_major"); + + return success(); +} + llvm::LogicalResult mlir::pto::TGatherOp::verify() { auto isSupportedGatherElemTypeA5Index = [&](Type ty) -> bool { @@ -11282,6 +11448,12 @@ void TPrintOp::getEffects( PTO_ADD_WRITE(getSrcMutable()); } +void TImg2colOp::getEffects( + SmallVectorImpl> &effects) { + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + #undef PTO_DEFINE_TERNARY_EFFECTS #undef PTO_DEFINE_BINARY_EFFECTS #undef PTO_DEFINE_UNARY_EFFECTS diff --git a/lib/PTO/IR/PTOAttrs.cpp b/lib/PTO/IR/PTOAttrs.cpp index cc4915f83..aa2b69d43 100644 --- a/lib/PTO/IR/PTOAttrs.cpp +++ b/lib/PTO/IR/PTOAttrs.cpp @@ -29,7 +29,10 @@ constexpr int32_t kFractalSize1024 = 1024; constexpr int32_t kBLayoutRowMajor = static_cast(BLayout::RowMajor); constexpr int32_t kBLayoutColMajor = static_cast(BLayout::ColMajor); constexpr int32_t kSLayoutNoneBox = static_cast(SLayout::NoneBox); +constexpr int32_t kSLayoutRowMajor = static_cast(SLayout::RowMajor); constexpr int32_t kSLayoutColMajor = static_cast(SLayout::ColMajor); +constexpr int32_t kSLayoutNC1HWC0 = static_cast(SLayout::NC1HWC0); +constexpr int32_t kSLayoutNDC1HWC0 = static_cast(SLayout::NDC1HWC0); constexpr int32_t kPadValueNull = static_cast(PadValue::Null); constexpr int32_t kPadValueMin = static_cast(PadValue::Min); constexpr int32_t kCompactModeNull = static_cast(CompactMode::Null); @@ -45,7 +48,8 @@ TileBufConfigAttr TileBufConfigAttr::getDefault(MLIRContext *ctx) { PadValueAttr pv = PadValueAttr::get(ctx, PadValue::Null); CompactModeAttr compact = CompactModeAttr::get(ctx, CompactMode::Null); IntegerAttr sz = b.getI32IntegerAttr(kFractalSize512); - return TileBufConfigAttr::get(ctx, bl, sl, sz, pv, compact); + return TileBufConfigAttr::get(ctx, bl, sl, sz, pv, compact, + Attribute()); } bool TileBufConfigAttr::isDefault() const { @@ -54,7 +58,17 @@ bool TileBufConfigAttr::isDefault() const { getSLayout() == d.getSLayout() && getSFractalSize() == d.getSFractalSize() && getPad() == d.getPad() && - getCompactMode() == d.getCompactMode(); + getCompactMode() == d.getCompactMode() && + getImg2colConfig() == d.getImg2colConfig(); +} + +DictionaryAttr TileBufConfigAttr::getImg2colConfigDict() const { + return mlir::dyn_cast_or_null(getImg2colConfig()); +} + +bool TileBufConfigAttr::hasImg2colConfig() const { + auto dict = getImg2colConfigDict(); + return dict && !dict.empty(); } static int32_t getLayoutInt(Attribute a, int32_t def) { @@ -71,7 +85,8 @@ LogicalResult TileBufConfigAttr::verify(function_ref emitE Attribute sLayout, IntegerAttr sFractalSize, Attribute pad, - Attribute compactMode) { + Attribute compactMode, + Attribute img2colConfig) { if (!bLayout || (!mlir::isa(bLayout) && !mlir::isa(bLayout))) return emitError() << "blayout must be BLayoutAttr or i32 integer attr", failure(); if (!sLayout || (!mlir::isa(sLayout) && !mlir::isa(sLayout))) @@ -96,7 +111,7 @@ LogicalResult TileBufConfigAttr::verify(function_ref emitE return emitError() << "unsupported blayout value: " << blv, failure(); int32_t slv = getLayoutInt(sLayout, -1); - if (slv < kSLayoutNoneBox || slv > kSLayoutColMajor) + if (slv < kSLayoutNoneBox || slv > kSLayoutNDC1HWC0) return emitError() << "unsupported slayout value: " << slv, failure(); int32_t pvv = getLayoutInt(pad, -1); @@ -107,6 +122,21 @@ LogicalResult TileBufConfigAttr::verify(function_ref emitE if (cmv < kCompactModeNull || cmv > kCompactModeRowPlusOne) return emitError() << "unsupported compact_mode value: " << cmv, failure(); + if (img2colConfig) { + auto dict = mlir::dyn_cast(img2colConfig); + if (!dict) { + emitError() << "img2col_config must be a dictionary attribute"; + return failure(); + } + for (NamedAttribute named : dict) { + if (!mlir::isa(named.getValue())) { + emitError() << "img2col_config field '" << named.getName() + << "' must be an integer attribute"; + return failure(); + } + } + } + return success(); } @@ -141,11 +171,13 @@ Attribute TileBufConfigAttr::parse(AsmParser &p, Type) { IntegerAttr sz = def.getSFractalSize(); PadValueAttr pv = def.getPad(); CompactModeAttr compact = def.getCompactMode(); + Attribute img2colConfig = def.getImg2colConfig(); if (p.parseLess()) return {}; if (succeeded(p.parseOptionalGreater())) - return TileBufConfigAttr::get(ctx, bl, sl, sz, pv, compact); + return TileBufConfigAttr::get(ctx, bl, sl, sz, pv, compact, + img2colConfig); bool parsedGreater = false; while (!parsedGreater) { @@ -177,6 +209,8 @@ Attribute TileBufConfigAttr::parse(AsmParser &p, Type) { if (p.parseAttribute(a)) return {}; compact = toCompactModeAttr(ctx, a); if (!compact) return {}; + } else if (key == "img2col_config") { + if (p.parseAttribute(img2colConfig)) return {}; } else { p.emitError(p.getCurrentLocation(), "unknown key in tile_buf_config: ") << key; return {}; @@ -188,7 +222,7 @@ Attribute TileBufConfigAttr::parse(AsmParser &p, Type) { if (p.parseComma()) return {}; } - return TileBufConfigAttr::get(ctx, bl, sl, sz, pv, compact); + return TileBufConfigAttr::get(ctx, bl, sl, sz, pv, compact, img2colConfig); } void TileBufConfigAttr::print(AsmPrinter &p) const { @@ -198,5 +232,7 @@ void TileBufConfigAttr::print(AsmPrinter &p) const { p << ", s_fractal_size=" << (int32_t)getSFractalSize().getInt(); p << ", pad=" << getPad(); p << ", compact=" << getCompactMode(); + if (auto dict = getImg2colConfigDict()) + p << ", img2col_config=" << dict; p << ">"; } diff --git a/lib/PTO/IR/PTOTypeDefs.cpp b/lib/PTO/IR/PTOTypeDefs.cpp index a3d5ab596..5526095fc 100644 --- a/lib/PTO/IR/PTOTypeDefs.cpp +++ b/lib/PTO/IR/PTOTypeDefs.cpp @@ -210,6 +210,12 @@ mlir::Attribute TileBufType::getPadValueAttr() const { return getConfigAttr().ge mlir::Attribute TileBufType::getCompactModeAttr() const { return getConfigAttr().getCompactMode(); } +mlir::DictionaryAttr TileBufType::getImg2colConfigAttr() const { + return getConfigAttr().getImg2colConfigDict(); +} +bool TileBufType::hasImg2colConfig() const { + return getConfigAttr().hasImg2colConfig(); +} // ✅ numeric getters(可选) int32_t TileBufType::getSFractalSizeI32() const { @@ -254,8 +260,122 @@ struct ParsedTileBufFields { int64_t fractal = 0; uint32_t padInt = 0; uint32_t compactInt = 0; + bool hasImg2colConfig = false; + int64_t fmapH = 0; + int64_t fmapW = 0; + int64_t padL = 0; + int64_t padR = 0; + int64_t padT = 0; + int64_t padB = 0; + int64_t strideH = 1; + int64_t strideW = 1; + int64_t dilationH = 1; + int64_t dilationW = 1; + int64_t filterH = 1; + int64_t filterW = 1; + int64_t channelSize = 0; + int64_t dstStride = 0; + int64_t dstMPosition = 0; + int64_t repeatStride = 0; + int64_t repeatTime = 0; + int64_t repeatMode = 0; + int64_t transpose = 0; }; +static LogicalResult parseTileBufImg2colField(AsmParser &parser, StringRef key, + ParsedTileBufFields &fields) { + int64_t value = 0; + if (failed(parser.parseInteger(value))) + return failure(); + fields.hasImg2colConfig = true; + + if (key == "fmap_h") + fields.fmapH = value; + else if (key == "fmap_w") + fields.fmapW = value; + else if (key == "pad_l") + fields.padL = value; + else if (key == "pad_r") + fields.padR = value; + else if (key == "pad_t") + fields.padT = value; + else if (key == "pad_b") + fields.padB = value; + else if (key == "stride_h") + fields.strideH = value; + else if (key == "stride_w") + fields.strideW = value; + else if (key == "dilation_h") + fields.dilationH = value; + else if (key == "dilation_w") + fields.dilationW = value; + else if (key == "filter_h") + fields.filterH = value; + else if (key == "filter_w") + fields.filterW = value; + else if (key == "channel_size") + fields.channelSize = value; + else if (key == "dst_stride") + fields.dstStride = value; + else if (key == "dst_m_position") + fields.dstMPosition = value; + else if (key == "repeat_stride") + fields.repeatStride = value; + else if (key == "repeat_time") + fields.repeatTime = value; + else if (key == "repeat_mode") + fields.repeatMode = value; + else if (key == "transpose") + fields.transpose = value; + else + return failure(); + + return success(); +} + +static bool isTileBufImg2colKey(StringRef key) { + return key == "fmap_h" || key == "fmap_w" || key == "pad_l" || + key == "pad_r" || key == "pad_t" || key == "pad_b" || + key == "stride_h" || key == "stride_w" || key == "dilation_h" || + key == "dilation_w" || key == "filter_h" || key == "filter_w" || + key == "channel_size" || key == "dst_stride" || + key == "dst_m_position" || key == "repeat_stride" || + key == "repeat_time" || key == "repeat_mode" || key == "transpose"; +} + +static DictionaryAttr buildImg2colConfigAttr(MLIRContext *ctx, + const ParsedTileBufFields &fields) { + if (!fields.hasImg2colConfig) + return {}; + + Builder b(ctx); + NamedAttrList attrs; + auto i64 = [&](StringRef name, int64_t value) { + attrs.set(name, b.getI64IntegerAttr(value)); + }; + + i64("fmap_h", fields.fmapH); + i64("fmap_w", fields.fmapW); + i64("pad_l", fields.padL); + i64("pad_r", fields.padR); + i64("pad_t", fields.padT); + i64("pad_b", fields.padB); + i64("stride_h", fields.strideH); + i64("stride_w", fields.strideW); + i64("dilation_h", fields.dilationH); + i64("dilation_w", fields.dilationW); + i64("filter_h", fields.filterH); + i64("filter_w", fields.filterW); + i64("channel_size", fields.channelSize); + i64("dst_stride", fields.dstStride); + i64("dst_m_position", fields.dstMPosition); + i64("repeat_stride", fields.repeatStride); + i64("repeat_time", fields.repeatTime); + i64("repeat_mode", fields.repeatMode); + i64("transpose", fields.transpose); + return DictionaryAttr::get(ctx, attrs); +} + static LogicalResult parseTileBufUInt32Value(AsmParser &parser, StringRef key, uint32_t &value) { int64_t parsedValue = 0; @@ -294,6 +414,20 @@ static LogicalResult parseLegacyTileBufFields(AsmParser &parser, return success(); } +static LogicalResult parseLegacyOptionalTileBufField(AsmParser &parser, + StringRef key, + ParsedTileBufFields &fields) { + if (key == "compact") + return parseTileBufUInt32Value(parser, key, fields.compactInt); + if (isTileBufImg2colKey(key)) + return parseTileBufImg2colField(parser, key, fields); + + parser.emitError(parser.getCurrentLocation(), + "unknown key in tile_buf legacy syntax: ") + << key; + return failure(); +} + static LogicalResult parseCompactTileBufFields(AsmParser &parser, StringRef firstToken, ParsedTileBufFields &fields) { @@ -430,6 +564,12 @@ static LogicalResult parseCompactTileBufFields(AsmParser &parser, continue; } + if (isTileBufImg2colKey(key)) { + if (failed(parseTileBufImg2colField(parser, key, fields))) + return failure(); + continue; + } + parser.emitError(parser.getCurrentLocation(), "unknown key in tile_buf compact syntax: ") << key; @@ -489,8 +629,9 @@ static Type buildTileBufType(AsmParser &parser, auto padAttr = PadValueAttr::get(ctx, pv.value()); auto compactAttr = CompactModeAttr::get(ctx, compact.value()); auto memorySpaceAttr = AddressSpaceAttr::get(ctx, memorySpace.value()); - auto cfg = TileBufConfigAttr::get(ctx, blAttr, slAttr, fractalAttr, padAttr, - compactAttr); + auto cfg = TileBufConfigAttr::get( + ctx, blAttr, slAttr, fractalAttr, padAttr, compactAttr, + buildImg2colConfigAttr(ctx, fields)); TileBufShape shape{fields.rows, fields.cols}; TileBufShape validShape{fields.vrow, fields.vcol}; @@ -523,10 +664,13 @@ Type TileBufType::parse(AsmParser &parser) { return Type(); } - if (isLegacySyntax && succeeded(parser.parseOptionalComma())) { - if (failed(parseTileBufKeyEq(parser, "compact")) || - failed(parseTileBufUInt32Value(parser, "compact", fields.compactInt))) { - return Type(); + if (isLegacySyntax) { + while (succeeded(parser.parseOptionalComma())) { + StringRef key; + if (failed(parser.parseKeyword(&key)) || failed(parser.parseEqual())) + return Type(); + if (failed(parseLegacyOptionalTileBufField(parser, key, fields))) + return Type(); } } @@ -609,6 +753,7 @@ void mlir::pto::TileBufType::print(mlir::AsmPrinter &printer) const { auto defaultPad = llvm::dyn_cast(defaultCfg.getPad()); auto defaultCompact = llvm::dyn_cast(defaultCfg.getCompactMode()); + auto img2colConfig = cfg.getImg2colConfigDict(); auto vs = getValidShape(); int64_t vrow = rows; @@ -654,6 +799,32 @@ void mlir::pto::TileBufType::print(mlir::AsmPrinter &printer) const { printer << ", pad=" << stringifyLocFromPad(cfg.getPad()); if (printCompact) printer << ", compact=" << stringifyCompactModeInt(cfg.getCompactMode()); + if (img2colConfig) { + auto printImg2col = [&](StringRef key) { + if (auto value = + mlir::dyn_cast_or_null(img2colConfig.get(key))) + printer << ", " << key << "=" << value.getInt(); + }; + printImg2col("fmap_h"); + printImg2col("fmap_w"); + printImg2col("pad_l"); + printImg2col("pad_r"); + printImg2col("pad_t"); + printImg2col("pad_b"); + printImg2col("stride_h"); + printImg2col("stride_w"); + printImg2col("dilation_h"); + printImg2col("dilation_w"); + printImg2col("filter_h"); + printImg2col("filter_w"); + printImg2col("channel_size"); + printImg2col("dst_stride"); + printImg2col("dst_m_position"); + printImg2col("repeat_stride"); + printImg2col("repeat_time"); + printImg2col("repeat_mode"); + printImg2col("transpose"); + } printer << ">"; } diff --git a/lib/PTO/Transforms/PTOPlanMemory.cpp b/lib/PTO/Transforms/PTOPlanMemory.cpp index 906e679f3..f90ed3b99 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.cpp +++ b/lib/PTO/Transforms/PTOPlanMemory.cpp @@ -469,6 +469,8 @@ void MemLivenessAnalysis::RecursionIR(Region *region, Liveness live) { pto::BuildAsyncSessionOp, pto::TPutAsyncOp, pto::TGetAsyncOp, pto::TPutOp, pto::TGetOp, pto::TNotifyOp, pto::TWaitOp, pto::TTestOp, + pto::SetImg2colFmatrixOp, pto::SetImg2colPaddingOp, + pto::SetImg2colRepeatOp, pto::SyncAllOp, pto::TBroadcastOp, pto::CommTGatherOp, pto::CommTScatterOp, pto::TReduceOp>(op)) { diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index b3f0c6bbd..78919d083 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -423,6 +423,165 @@ static std::optional getEmitCTileTypeString(pto::TileBufType type) tileBufCompactToken(configAttr) + ">"; } +struct Img2colConfigData { + int64_t fmapH = 0; + int64_t fmapW = 0; + int64_t padL = 0; + int64_t padR = 0; + int64_t padT = 0; + int64_t padB = 0; + int64_t strideH = 1; + int64_t strideW = 1; + int64_t dilationH = 1; + int64_t dilationW = 1; + int64_t filterH = 1; + int64_t filterW = 1; + int64_t channelSize = 0; + int64_t dstStride = 0; + int64_t dstMPosition = 0; + int64_t repeatStride = 0; + int64_t repeatTime = 0; + int64_t repeatMode = 0; + int64_t transpose = 0; +}; + +static std::optional getImg2colConfigInt(DictionaryAttr dict, + StringRef key) { + if (!dict) + return std::nullopt; + if (auto attr = dyn_cast_or_null(dict.get(key))) + return attr.getInt(); + return std::nullopt; +} + +static std::optional +resolveTileBufConfigFromValue(Value value) { + value = peelUnrealized(value); + while (Operation *def = value.getDefiningOp()) { + if (auto bind = dyn_cast(def)) { + if (auto config = bind.getConfigAttr()) + return config; + value = peelUnrealized(bind.getSource()); + continue; + } + if (auto cast = dyn_cast(def)) { + if (auto config = cast.getConfig()) + return *config; + break; + } + if (auto subview = dyn_cast(def)) { + value = peelUnrealized(subview.getSource()); + continue; + } + if (auto reinterpret = dyn_cast(def)) { + value = peelUnrealized(reinterpret.getSource()); + continue; + } + if (auto cast = dyn_cast(def)) { + value = peelUnrealized(cast.getSource()); + continue; + } + if (auto unrealized = dyn_cast(def)) { + if (unrealized->getNumOperands() == 0) + break; + value = peelUnrealized(unrealized.getOperand(0)); + continue; + } + break; + } + + if (auto tileTy = dyn_cast(value.getType())) + return tileTy.getConfigAttr(); + return std::nullopt; +} + +static std::optional +getImg2colConfigData(pto::TileBufConfigAttr configAttr) { + if (!configAttr) + return std::nullopt; + auto dict = configAttr.getImg2colConfigDict(); + if (!dict) + return std::nullopt; + + auto getRequired = [&](StringRef key, int64_t fallback = 0) + -> std::optional { + if (auto value = getImg2colConfigInt(dict, key)) + return value; + return fallback; + }; + + Img2colConfigData data; + auto fmapH = getRequired("fmap_h"); + auto fmapW = getRequired("fmap_w"); + auto channelSize = getRequired("channel_size"); + auto repeatStride = getRequired("repeat_stride"); + auto repeatTime = getRequired("repeat_time"); + auto repeatMode = getRequired("repeat_mode"); + auto dstStride = getRequired("dst_stride"); + if (!fmapH || !fmapW || !channelSize || !repeatStride || !repeatTime || + !repeatMode || !dstStride) + return std::nullopt; + + data.fmapH = *fmapH; + data.fmapW = *fmapW; + data.padL = getRequired("pad_l", 0).value_or(0); + data.padR = getRequired("pad_r", 0).value_or(0); + data.padT = getRequired("pad_t", 0).value_or(0); + data.padB = getRequired("pad_b", 0).value_or(0); + data.strideH = getRequired("stride_h", 1).value_or(1); + data.strideW = getRequired("stride_w", 1).value_or(1); + data.dilationH = getRequired("dilation_h", 1).value_or(1); + data.dilationW = getRequired("dilation_w", 1).value_or(1); + data.filterH = getRequired("filter_h", 1).value_or(1); + data.filterW = getRequired("filter_w", 1).value_or(1); + data.channelSize = *channelSize; + data.dstStride = *dstStride; + data.dstMPosition = getRequired("dst_m_position", 0).value_or(0); + data.repeatStride = *repeatStride; + data.repeatTime = *repeatTime; + data.repeatMode = *repeatMode; + data.transpose = getRequired("transpose", 0).value_or(0); + return data; +} + +static std::optional getImg2colConfigData(Value value) { + auto configAttr = resolveTileBufConfigFromValue(value); + if (!configAttr) + return std::nullopt; + return getImg2colConfigData(*configAttr); +} + +static pto::SetFmatrixMode getImg2colMode(Attribute modeAttr) { + if (auto attr = dyn_cast_or_null(modeAttr)) + return attr.getValue(); + return pto::SetFmatrixMode::FMATRIX_A_MANUAL; +} + +static bool isImg2colBMode(pto::SetFmatrixMode mode) { + return mode == pto::SetFmatrixMode::FMATRIX_B_AUTO || + mode == pto::SetFmatrixMode::FMATRIX_B_MANUAL; +} + +static StringRef img2colModeToken(pto::SetFmatrixMode mode) { + switch (mode) { + case pto::SetFmatrixMode::FMATRIX_A_AUTO: + return "SetFmatrixMode::FMATRIX_A_AUTO"; + case pto::SetFmatrixMode::FMATRIX_B_AUTO: + return "SetFmatrixMode::FMATRIX_B_AUTO"; + case pto::SetFmatrixMode::FMATRIX_A_MANUAL: + return "SetFmatrixMode::FMATRIX_A_MANUAL"; + case pto::SetFmatrixMode::FMATRIX_B_MANUAL: + return "SetFmatrixMode::FMATRIX_B_MANUAL"; + } + llvm_unreachable("unknown img2col mode"); +} + +static ArrayAttr buildImg2colUseBTemplateArgs(ConversionPatternRewriter &rewriter, + bool useB) { + return rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(rewriter.getContext(), useB ? "true" : "false")}); +} + //===----------------------------------------------------------------------===// // Type Converter //===----------------------------------------------------------------------===// @@ -3702,6 +3861,8 @@ static std::string tileBufSLayoutToken(pto::TileBufConfigAttr configAttr) { int32_t slVal = static_cast(slAttr.getValue()); slTok = (slVal == 1) ? "SLayout::RowMajor" : (slVal == 2) ? "SLayout::ColMajor" + : (slVal == 3) ? "SLayout::NC1HWC0" + : (slVal == 4) ? "SLayout::NDC1HWC0" : "SLayout::NoneBox"; } return slTok; @@ -4019,7 +4180,11 @@ struct PointerCastConversion : public OpConversionPattern { if (auto attr = dyn_cast(config.getSLayout())) slVal = static_cast(attr.getValue()); - std::string slStr = (slVal == 1) ? "SLayout::RowMajor" : (slVal == 2) ? "SLayout::ColMajor" : "SLayout::NoneBox"; + std::string slStr = (slVal == 1) ? "SLayout::RowMajor" + : (slVal == 2) ? "SLayout::ColMajor" + : (slVal == 3) ? "SLayout::NC1HWC0" + : (slVal == 4) ? "SLayout::NDC1HWC0" + : "SLayout::NoneBox"; int32_t frVal = 0; if (auto attr = dyn_cast(config.getSFractalSize())) frVal = attr.getInt(); @@ -8349,6 +8514,156 @@ struct PTOInsertFPToEmitC : public OpConversionPattern { return success(); } }; + +static LogicalResult emitTImg2colFmatrix(ConversionPatternRewriter &rewriter, + Location loc, Value srcValue, + bool useB) { + auto config = getImg2colConfigData(srcValue); + if (!config) + return failure(); + + MLIRContext *ctx = rewriter.getContext(); + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + SmallVector fmatrixArgs{ + makeEmitCIntConstant(rewriter, loc, i32Ty, config->fmapH), + makeEmitCIntConstant(rewriter, loc, i32Ty, config->fmapW), + makeEmitCIntConstant(rewriter, loc, i32Ty, config->padL), + makeEmitCIntConstant(rewriter, loc, i32Ty, config->padR), + makeEmitCIntConstant(rewriter, loc, i32Ty, config->padT), + makeEmitCIntConstant(rewriter, loc, i32Ty, config->padB)}; + auto templateArgs = buildImg2colUseBTemplateArgs(rewriter, useB); + rewriter.create( + loc, TypeRange{}, "PTOAS__TIMG2COL_SET_FMATRIX", ArrayAttr{}, + templateArgs, ValueRange{fmatrixArgs}); + return success(); +} + +static LogicalResult emitTImg2colPadding(ConversionPatternRewriter &rewriter, + Location loc, Value srcValue, + bool useB) { + auto configAttr = resolveTileBufConfigFromValue(srcValue); + if (!configAttr) + return failure(); + Type elemTy = srcValue.getType(); + if (auto tb = dyn_cast(elemTy)) + elemTy = tb.getElementType(); + else if (auto mr = dyn_cast(elemTy)) + elemTy = mr.getElementType(); + else + return failure(); + + MLIRContext *ctx = rewriter.getContext(); + auto padAttr = dyn_cast((*configAttr).getPad()); + if (padAttr && padAttr.getValue() == pto::PadValue::Zero) { + auto templateArgs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, getEmitCScalarTypeToken(elemTy)), + emitc::OpaqueAttr::get(ctx, useB ? "true" : "false") + }); + rewriter.create(loc, TypeRange{}, + "PTOAS__TIMG2COL_SET_PADDING_ZERO", + ArrayAttr{}, templateArgs, + ValueRange{}); + } + return success(); +} + +static LogicalResult emitTImg2colRpt(ConversionPatternRewriter &rewriter, + Location loc, Value srcValue, + bool useB) { + auto config = getImg2colConfigData(srcValue); + if (!config) + return failure(); + + MLIRContext *ctx = rewriter.getContext(); + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + SmallVector repeatArgs{ + makeEmitCIntConstant(rewriter, loc, i32Ty, config->repeatStride), + makeEmitCIntConstant(rewriter, loc, i32Ty, config->repeatTime), + makeEmitCIntConstant(rewriter, loc, i32Ty, config->repeatMode), + makeEmitCIntConstant(rewriter, loc, i32Ty, config->dstStride), + makeEmitCIntConstant(rewriter, loc, i32Ty, config->dstMPosition)}; + auto templateArgs = buildImg2colUseBTemplateArgs(rewriter, useB); + rewriter.create( + loc, TypeRange{}, "PTOAS__TIMG2COL_SET_RPT", ArrayAttr{}, templateArgs, + ValueRange{repeatArgs}); + return success(); +} + +struct PTOSetImg2colFmatrixToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::SetImg2colFmatrixOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto mode = getImg2colMode(op.getFmatrixModeAttr()); + if (failed(emitTImg2colFmatrix(rewriter, op.getLoc(), op.getConfig(), + isImg2colBMode(mode)))) + return failure(); + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOSetImg2colPaddingToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::SetImg2colPaddingOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto mode = getImg2colMode(op.getFmatrixModeAttr()); + if (failed(emitTImg2colPadding(rewriter, op.getLoc(), op.getConfig(), + isImg2colBMode(mode)))) + return failure(); + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOSetImg2colRepeatToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::SetImg2colRepeatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto mode = getImg2colMode(op.getFmatrixModeAttr()); + if (failed(emitTImg2colRpt(rewriter, op.getLoc(), op.getConfig(), + isImg2colBMode(mode)))) + return failure(); + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOTImg2colToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TImg2colOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + auto mode = getImg2colMode(op.getFmatrixModeAttr()); + + auto srcTy = dyn_cast(op.getSrc().getType()); + auto dstTy = dyn_cast(op.getDst().getType()); + if (!srcTy || !dstTy) + return failure(); + ArrayAttr templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(rewriter.getContext(), + getEmitCTileTypeString(dstTy).value_or("")), + emitc::OpaqueAttr::get(rewriter.getContext(), + getEmitCTileTypeString(srcTy).value_or("")), + emitc::OpaqueAttr::get(rewriter.getContext(), img2colModeToken(mode))}); + auto i32Ty = emitc::OpaqueType::get(rewriter.getContext(), "uint16_t"); + Value posM = makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, op.getPosM()); + Value posK = makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, op.getPosK()); + rewriter.create(op.getLoc(), TypeRange{}, "TIMG2COL", + ArrayAttr{}, templateArgs, + ValueRange{dst, src, posM, posK}); + rewriter.eraseOp(op); + return success(); + } +}; + //===----------------------------------------------------------------------===// // pto.tfillpad lowering -> TFILLPAD(dst, src) //===----------------------------------------------------------------------===// @@ -10821,13 +11136,7 @@ struct PTOBindTileToEmitC : public OpConversionPattern { } } - std::string slTok = "SLayout::NoneBox"; - if (auto slAttr = dyn_cast(configAttr.getSLayout())) { - int32_t slVal = static_cast(slAttr.getValue()); - slTok = (slVal == 1) ? "SLayout::RowMajor" - : (slVal == 2) ? "SLayout::ColMajor" - : "SLayout::NoneBox"; - } + std::string slTok = tileBufSLayoutToken(configAttr); int32_t fractal = 512; if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) @@ -12255,6 +12564,9 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, + ctx); patterns.add( typeConverter, ctx); patterns.add(typeConverter, ctx); @@ -12500,6 +12812,7 @@ struct EmitPTOManualPass bool needsTRandomHelper = false; bool needsGlobalTensorDataHelper = false; bool needsCommInclude = false; + bool needsImg2colHelper = false; mop.walk([&](Operation *op) { if (isa(op)) needsEventIdArrayHelper = true; @@ -12507,6 +12820,10 @@ struct EmitPTOManualPass needsTRandomHelper = true; if (isa(op)) needsGlobalTensorDataHelper = true; + if (isa(op)) + needsImg2colHelper = true; if (isa(dst, key, counter); } +)cpp")); + } + if (needsImg2colHelper) { + builder.create( + loc, builder.getStringAttr(R"cpp( +template +static AICORE inline void PTOAS__TIMG2COL_SET_PADDING_ZERO() { + if constexpr (UseB) { + set_padding_b(static_cast(0)); + } else { + set_padding(static_cast(0)); + } +} + +template +static AICORE inline void PTOAS__TIMG2COL_SET_FMATRIX( + int32_t fmapH, int32_t fmapW, int32_t padL, int32_t padR, int32_t padT, + int32_t padB) { + uint64_t fmatrixReg = 0; + constexpr uint32_t l1ShiftBit = 16; + fmatrixReg |= uint64_t(fmapW & 0xFFFF); + fmatrixReg |= uint64_t(fmapH & 0xFFFF) << l1ShiftBit; + + constexpr uint32_t padListShiftBit = 8; + constexpr uint32_t padListShiftBase = 32; + fmatrixReg |= uint64_t(padL & 0xFF) << padListShiftBase; + fmatrixReg |= uint64_t(padR & 0xFF) << (padListShiftBase + padListShiftBit); + fmatrixReg |= uint64_t(padT & 0xFF) << (padListShiftBase + padListShiftBit * 2); + fmatrixReg |= uint64_t(padB & 0xFF) << (padListShiftBase + padListShiftBit * 3); + if constexpr (UseB) { + set_fmatrix_b(fmatrixReg); + } else { + set_fmatrix(fmatrixReg); + } +} + +template +static AICORE inline void PTOAS__TIMG2COL_SET_RPT( + int32_t repeatStride, int32_t repeatTime, int32_t repeatMode, + int32_t dstStride, int32_t dstMposition) { + uint64_t rptConfig = 0; + constexpr uint32_t repeatTimeShiftBit = 16; + constexpr uint32_t repeatModeShiftBit = 24; + constexpr uint32_t dstStrideShiftBit = 32; + constexpr uint32_t dstMpositionShiftBit = 48; + rptConfig |= uint64_t(repeatStride); + rptConfig |= uint64_t(repeatTime) << repeatTimeShiftBit; + rptConfig |= uint64_t(repeatMode) << repeatModeShiftBit; + rptConfig |= uint64_t(dstStride) << dstStrideShiftBit; + rptConfig |= uint64_t(dstMposition) << dstMpositionShiftBit; + if constexpr (UseB) { + set_l3d_rpt_b(rptConfig); + } else { + set_l3d_rpt(rptConfig); + } +} )cpp")); } builder.create( diff --git a/python/pto/dialects/pto.py b/python/pto/dialects/pto.py index a0b5037c7..cf70d7c51 100644 --- a/python/pto/dialects/pto.py +++ b/python/pto/dialects/pto.py @@ -69,6 +69,8 @@ def get_op_result_or_value(value): SLayoutAttr = _pto_mod.SLayoutAttr PadValue = _pto_mod.PadValue PadValueAttr = _pto_mod.PadValueAttr +SetFmatrixMode = _pto_mod.SetFmatrixMode +SetFmatrixModeAttr = _pto_mod.SetFmatrixModeAttr CompactMode = _pto_mod.CompactMode CompactModeAttr = _pto_mod.CompactModeAttr AccToVecMode = _pto_mod.AccToVecMode @@ -125,6 +127,8 @@ def get_op_result_or_value(value): "SLayoutAttr", "PadValue", "PadValueAttr", + "SetFmatrixMode", + "SetFmatrixModeAttr", "CompactMode", "CompactModeAttr", "AccToVecMode", @@ -159,6 +163,7 @@ def get_op_result_or_value(value): "QuantTypeAttr", "TileBufConfigAttr", "TileConfig", + "make_img2col_config", # High-level sync helpers "record_event", "wait_event", @@ -631,6 +636,61 @@ class TileConfig: fractalMxSize = 32 +def make_img2col_config( + *, + fmap_h, + fmap_w, + pad_l=0, + pad_r=0, + pad_t=0, + pad_b=0, + stride_h=1, + stride_w=1, + dilation_h=1, + dilation_w=1, + filter_h=1, + filter_w=1, + channel_size, + dst_stride, + dst_m_position=0, + repeat_stride, + repeat_time, + repeat_mode, + transpose=0, + context=None, +): + ctx = context if context is not None else _ods_ir.Context.current + i64 = _ods_ir.IntegerType.get_signless(64, ctx) + + def _i64(value): + return _ods_ir.IntegerAttr.get(i64, int(value)) + + return _ods_ir.DictAttr.get( + { + "fmap_h": _i64(fmap_h), + "fmap_w": _i64(fmap_w), + "pad_l": _i64(pad_l), + "pad_r": _i64(pad_r), + "pad_t": _i64(pad_t), + "pad_b": _i64(pad_b), + "stride_h": _i64(stride_h), + "stride_w": _i64(stride_w), + "dilation_h": _i64(dilation_h), + "dilation_w": _i64(dilation_w), + "filter_h": _i64(filter_h), + "filter_w": _i64(filter_w), + "channel_size": _i64(channel_size), + "dst_stride": _i64(dst_stride), + "dst_m_position": _i64(dst_m_position), + "repeat_stride": _i64(repeat_stride), + "repeat_time": _i64(repeat_time), + "repeat_mode": _i64(repeat_mode), + "transpose": _i64(transpose), + }, + context=ctx, + ) + + _PARTITION_VIEW_UNSET = object() _GeneratedPartitionViewOp = PartitionViewOp _TSCATTER_UNSET = object() diff --git a/test/lit/pto/timg2col_emitc.pto b/test/lit/pto/timg2col_emitc.pto new file mode 100644 index 000000000..6bfbdaae5 --- /dev/null +++ b/test/lit/pto/timg2col_emitc.pto @@ -0,0 +1,31 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s + +module { + func.func @timg2col_setup_only() attributes {pto.kernel_kind = #pto.kernel_kind} { + %src = pto.alloc_tile : !pto.tile_buf + pto.set_img2col_fmatrix %src : !pto.tile_buf + pto.set_img2col_padding %src : !pto.tile_buf + pto.set_img2col_repeat %src : !pto.tile_buf + return + } + + func.func @timg2col_basic() attributes {pto.kernel_kind = #pto.kernel_kind} { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.timg2col ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} + +// CHECK: static AICORE inline void PTOAS__TIMG2COL_SET_PADDING_ZERO() +// CHECK: static AICORE inline void PTOAS__TIMG2COL_SET_FMATRIX( +// CHECK: static AICORE inline void PTOAS__TIMG2COL_SET_RPT( +// CHECK-LABEL: AICORE void timg2col_setup_only( +// CHECK: PTOAS__TIMG2COL_SET_FMATRIX( +// CHECK: PTOAS__TIMG2COL_SET_PADDING_ZERO() +// CHECK: PTOAS__TIMG2COL_SET_RPT( +// CHECK-LABEL: AICORE void timg2col_basic( +// CHECK: Tile +// CHECK: TIMG2COL< +// CHECK: SetFmatrixMode::FMATRIX_A_MANUAL diff --git a/test/lit/pto/timg2col_frontend_lowering.pto b/test/lit/pto/timg2col_frontend_lowering.pto new file mode 100644 index 000000000..df1c533f3 --- /dev/null +++ b/test/lit/pto/timg2col_frontend_lowering.pto @@ -0,0 +1,20 @@ +// RUN: ptoas --pto-arch=a5 --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @timg2col_frontend_lowering() attributes {pto.kernel_kind = #pto.kernel_kind} { + %src = pto.alloc_tile : !pto.tile_buf + pto.set_img2col_fmatrix %src : !pto.tile_buf + pto.set_img2col_padding %src : !pto.tile_buf + pto.set_img2col_repeat %src : !pto.tile_buf + return + } +} + +// CHECK-LABEL: func.func @timg2col_frontend_lowering() +// CHECK: pto.pointer_cast(%c0_i64) {config = #pto.tile_buf_config, slayout=#pto.slayout +// CHECK: img2col_config={channel_size = 16 : i64 +// CHECK: repeat_time = 4 : i64 +// CHECK: stride_w = 2 : i64 +// CHECK: pto.set_img2col_fmatrix %{{.*}} : memref<16x32xf16 +// CHECK: pto.set_img2col_padding %{{.*}} : memref<16x32xf16 +// CHECK: pto.set_img2col_repeat %{{.*}} : memref<16x32xf16 diff --git a/test/lit/pto/timg2col_missing_config_invalid.pto b/test/lit/pto/timg2col_missing_config_invalid.pto new file mode 100644 index 000000000..092f90927 --- /dev/null +++ b/test/lit/pto/timg2col_missing_config_invalid.pto @@ -0,0 +1,13 @@ +// RUN: not ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s + +module { + func.func @timg2col_missing_config() attributes {pto.kernel_kind = #pto.kernel_kind} { + %src = pto.declare_tile -> !pto.tile_buf + %dst = pto.declare_tile -> !pto.tile_buf + pto.timg2col ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} + +// CHECK: error: expects src to carry img2col_config metadata