154 changes: 151 additions & 3 deletions docs/PTO_IR_manual.md
@@ -150,16 +150,31 @@ A logical partition (slice) of a `tensor_view`. Holds shape and stride informati
| `v_row` | `int64` or `?` | Valid row count |
| `v_col` | `int64` or `?` | Valid column count |
| `blayout` | `BLayout` mnemonic | Base layout (`row_major` / `col_major`) |
| `slayout` | `SLayout` mnemonic | Secondary layout (`none_box` / `row_major` / `col_major`) |
| `slayout` | `SLayout` mnemonic | Secondary layout (`none_box` / `row_major` / `col_major` / `nc1hwc0` / `ndc1hwc0`) |
| `fractal` | `int32` | Fractal size |
| `pad` | `PadValue` mnemonic or integer literal | Padding policy/value selector (tests commonly use `pad=0`) |
| `fmap_h/pad_l/.../repeat_mode/transpose` | `int64` literals (optional) | Extra img2col metadata used when `slayout=nc1hwc0` or `ndc1hwc0` |

Here, `?` denotes a dynamic symbol resolved at runtime.

For `dtype=!pto.f4E1M2x2` and `dtype=!pto.f4E2M1x2`, the `rows`/`cols` and
`v_row`/`v_col` values are physical packed extents. In other words, the packed
dimension counts FP4 pairs stored per byte, not logical scalar FP4 elements.
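
To make the packed-extent rule concrete, here is a minimal C++ sketch of the arithmetic. `PackedFp4Extent`, `logicalCols`, and `storageBytes` are hypothetical names used only for illustration, and the sketch assumes that `cols` is the packed dimension, which the text above does not state.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper type, not part of PTO: the physical extents of a tile
// whose dtype packs two FP4 scalars per byte (f4E1M2x2 / f4E2M1x2).
struct PackedFp4Extent {
  std::int64_t rows; // physical rows, as written in the tile_buf type
  std::int64_t cols; // physical cols; ASSUMED here to be the packed dimension
};

// Logical FP4 scalars along the packed dimension: two per physical element.
inline std::int64_t logicalCols(PackedFp4Extent t) {
  assert(t.cols >= 0);
  return t.cols * 2;
}

// Storage in bytes: each physical element is one byte holding an FP4 pair.
inline std::int64_t storageBytes(PackedFp4Extent t) {
  assert(t.rows >= 0 && t.cols >= 0);
  return t.rows * t.cols;
}
```

Under that assumption, a tile declared with `rows=16, cols=32` holds 64 logical FP4 values per row and occupies 512 bytes.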

For A5 img2col authoring, PTO IR does not introduce a dedicated `ConvTile`
type. Instead, a MAT `tile_buf` with `slayout=nc1hwc0` or `slayout=ndc1hwc0`
plus img2col metadata fields represents the ConvTile surface:

- `fmap_h`, `fmap_w`
- `pad_l`, `pad_r`, `pad_t`, `pad_b`
- `stride_h`, `stride_w`
- `dilation_h`, `dilation_w`
- `filter_h`, `filter_w`
- `channel_size`
- `dst_stride`, `dst_m_position`
- `repeat_stride`, `repeat_time`, `repeat_mode`
- `transpose`
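
As a rough illustration of how these fields relate to the shape of the img2col result, the sketch below applies the textbook convolution output-size formula. The struct and function names are hypothetical, only a subset of the fields is mirrored, and the mapping of M to output pixels and K to `filter_h * filter_w * channel_size` is the usual im2col convention rather than something this manual specifies.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative mirror of a subset of the img2col metadata fields listed above.
struct Img2colMeta {
  std::int64_t fmap_h, fmap_w;
  std::int64_t pad_l, pad_r, pad_t, pad_b;
  std::int64_t stride_h, stride_w;
  std::int64_t dilation_h, dilation_w;
  std::int64_t filter_h, filter_w;
  std::int64_t channel_size;
};

// Standard convolution output extent along one axis.
inline std::int64_t convOutExtent(std::int64_t in, std::int64_t padLo,
                                  std::int64_t padHi, std::int64_t filter,
                                  std::int64_t dilation, std::int64_t stride) {
  assert(stride > 0 && dilation > 0 && filter > 0);
  const std::int64_t effectiveFilter = dilation * (filter - 1) + 1;
  const std::int64_t padded = in + padLo + padHi;
  assert(padded >= effectiveFilter);
  return (padded - effectiveFilter) / stride + 1;
}

// M extent under the usual im2col convention: one row per output pixel.
inline std::int64_t img2colRows(const Img2colMeta &m) {
  const std::int64_t outH = convOutExtent(m.fmap_h, m.pad_t, m.pad_b,
                                          m.filter_h, m.dilation_h, m.stride_h);
  const std::int64_t outW = convOutExtent(m.fmap_w, m.pad_l, m.pad_r,
                                          m.filter_w, m.dilation_w, m.stride_w);
  return outH * outW;
}

// K extent under the usual im2col convention: one column per filter tap
// and channel.
inline std::int64_t img2colCols(const Img2colMeta &m) {
  return m.filter_h * m.filter_w * m.channel_size;
}
```

For an 8x8 feature map with 3x3 filters, unit stride and dilation, padding of 1 on every side, and `channel_size=16`, this gives 64 output pixels and a K extent of 144.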

**Syntax:**
```mlir
!pto.tile_buf<loc=vec, dtype=f16, rows=16, cols=16, v_row=16, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
@@ -267,12 +282,16 @@ Composite attribute and component enums for tile buffer configuration.
| Parameter | Type | Description |
|-----------|------|-------------|
| `bLayout` | `BLayoutAttr` | Base layout (RowMajor / ColMajor) |
| `sLayout` | `SLayoutAttr` | Secondary layout (NoneBox / RowMajor / ColMajor) |
| `sLayout` | `SLayoutAttr` | Secondary layout (NoneBox / RowMajor / ColMajor / NC1HWC0 / NDC1HWC0) |
| `sFractalSize` | `IntegerAttr (i32)` | Secondary fractal size |
| `pad` | `PadValueAttr` | Pad value policy |
| `img2colConfig` | `DictionaryAttr` (optional) | Lowered img2col metadata dictionary for ConvTile-like MAT tiles |

**Syntax:** `#pto.tile_buf_config<row_major, none_box, 16, zero>`

For img2col-lowered MAT tiles, `img2colConfig` is carried as
`img2col_config={...}` inside `#pto.tile_buf_config<...>`.

**BLayout** (Base layout):

| Value | Int | Mnemonic |
@@ -287,6 +306,8 @@ Composite attribute and component enums for tile buffer configuration.
| `NoneBox` | 0 | `none_box` |
| `RowMajor` | 1 | `row_major` |
| `ColMajor` | 2 | `col_major` |
| `NC1HWC0` | 3 | `nc1hwc0` |
| `NDC1HWC0` | 4 | `ndc1hwc0` |

**PadValue** (Pad value policy):

@@ -299,7 +320,24 @@ Composite attribute and component enums for tile buffer configuration.

---

### 3.5 Layout
### 3.5 SetFmatrixMode

Selects whether img2col uses the A-register bank or B-register bank, and
whether the `TIMG2COL` wrapper auto-programs the auxiliary fmatrix/repeat/padding
registers.

| Value | Int | Mnemonic |
|-------|-----|----------|
| `FMATRIX_A_AUTO` | 0 | `fmatrix_a_auto` |
| `FMATRIX_B_AUTO` | 1 | `fmatrix_b_auto` |
| `FMATRIX_A_MANUAL` | 2 | `fmatrix_a_manual` |
| `FMATRIX_B_MANUAL` | 3 | `fmatrix_b_manual` |

**Attribute syntax:** `#pto.set_fmatrix_mode<fmatrix_a_auto>`
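
The four modes combine two independent choices: bank (A or B) and programming style (auto or manual). A minimal C++ sketch of that decoding follows; the enum values mirror the table above, while `usesBankB`, `isAutoProgrammed`, and `mnemonic` are hypothetical helpers and not PTO APIs.

```cpp
#include <cassert>
#include <cstdint>
#include <string_view>

// Integer values mirror the SetFmatrixMode table above.
enum class SetFmatrixMode : std::int32_t {
  FMATRIX_A_AUTO = 0,
  FMATRIX_B_AUTO = 1,
  FMATRIX_A_MANUAL = 2,
  FMATRIX_B_MANUAL = 3,
};

// True when the mode selects the B-register bank.
inline bool usesBankB(SetFmatrixMode m) {
  return m == SetFmatrixMode::FMATRIX_B_AUTO ||
         m == SetFmatrixMode::FMATRIX_B_MANUAL;
}

// True when the wrapper is expected to program the auxiliary registers itself.
inline bool isAutoProgrammed(SetFmatrixMode m) {
  return m == SetFmatrixMode::FMATRIX_A_AUTO ||
         m == SetFmatrixMode::FMATRIX_B_AUTO;
}

// Textual mnemonic as it appears inside #pto.set_fmatrix_mode<...>.
inline std::string_view mnemonic(SetFmatrixMode m) {
  switch (m) {
  case SetFmatrixMode::FMATRIX_A_AUTO:
    return "fmatrix_a_auto";
  case SetFmatrixMode::FMATRIX_B_AUTO:
    return "fmatrix_b_auto";
  case SetFmatrixMode::FMATRIX_A_MANUAL:
    return "fmatrix_a_manual";
  case SetFmatrixMode::FMATRIX_B_MANUAL:
    return "fmatrix_b_manual";
  }
  assert(false && "unknown SetFmatrixMode value");
  return "";
}
```

The helpers compare against named cases instead of relying on bit patterns, so they stay correct if the integer encoding is ever reordered.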

---

### 3.6 Layout

Global tensor layout inference for [`tensor_view` (Section 2.3)](#23-ptotensor_viewd0-x-d1-x-elementtype)/[`partition_tensor_view` (Section 2.4)](#24-ptopartition_tensor_viewd0-x-d1-x-elementtype). Tile buffers additionally use **Tile Buf config** (see 3.4) to describe physical/fractal layout.

@@ -7390,6 +7428,116 @@ pto.tfillpad_inplace ins(%tile : !pto.tile_buf<loc=vec, dtype=f32, rows=32, cols

---

##### `pto.set_img2col_fmatrix` - Program Img2col Fmatrix State

**Summary:** Programs the img2col fmatrix register from ConvTile-like MAT tile metadata.

**Arguments:**

| Name | Type | Description |
|------|------|-------------|
| `config` | `pto.tile_buf` | MAT tile with `slayout=nc1hwc0` or `ndc1hwc0` and img2col metadata |
| `fmatrixMode` | `SetFmatrixModeAttr` | Optional mode attribute, supplied through the attribute dictionary |

**Results:** None.

**Constraints & Verification:**

- A5 only.
- `config.loc` must be `mat`.
- `config.blayout` must be `row_major`.
- `config.slayout` must be `nc1hwc0` or `ndc1hwc0`.
- `config` must carry img2col metadata (`fmap_*`, `pad_*`, `stride_*`, `dilation_*`, `filter_*`, `channel_size`, `dst_stride`, `repeat_*`, `transpose`).
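
The structural checks above can be sketched as a plain C++ predicate. This is not the PTOAS verifier: `TileDesc`, the trimmed `Loc` enum, and `verifyImg2colConfigTile` are hypothetical stand-ins, and the A5-only target check is omitted.

```cpp
#include <optional>
#include <string>

// Hypothetical, trimmed-down stand-ins for the tile_buf properties used here.
enum class Loc { Vec, Mat, Left };
enum class BLayout { RowMajor = 0, ColMajor = 1 };
enum class SLayout {
  NoneBox = 0,
  RowMajor = 1,
  ColMajor = 2,
  NC1HWC0 = 3,
  NDC1HWC0 = 4,
};

struct TileDesc {
  Loc loc;
  BLayout blayout;
  SLayout slayout;
  bool hasImg2colConfig; // whether img2col metadata is attached
};

// Returns std::nullopt when the tile satisfies the listed constraints,
// otherwise a message naming the first violated rule.
inline std::optional<std::string>
verifyImg2colConfigTile(const TileDesc &t) {
  if (t.loc != Loc::Mat)
    return "config.loc must be mat";
  if (t.blayout != BLayout::RowMajor)
    return "config.blayout must be row_major";
  if (t.slayout != SLayout::NC1HWC0 && t.slayout != SLayout::NDC1HWC0)
    return "config.slayout must be nc1hwc0 or ndc1hwc0";
  if (!t.hasImg2colConfig)
    return "config must carry img2col metadata";
  return std::nullopt;
}
```

Returning the first failing rule as a message mirrors how an MLIR verifier typically emits a single diagnostic and stops.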

**Hardware Mapping:**

- EmitC lowers to `set_fmatrix` / `set_fmatrix_b` via a helper wrapper.
- Executes on the **FIX pipeline** (`PIPE_FIX`).

**Basic Example:**

```mlir
pto.set_img2col_fmatrix %src : !pto.tile_buf<mat, 16x32xf16, slayout=nc1hwc0, pad=1, fmap_h=8, fmap_w=8, pad_l=1, pad_r=1, pad_t=1, pad_b=1, stride_h=1, stride_w=1, dilation_h=1, dilation_w=1, filter_h=3, filter_w=3, channel_size=16, dst_stride=16, repeat_stride=16, repeat_time=4, repeat_mode=0>
```

---

##### `pto.set_img2col_padding` - Program Img2col Padding State

**Summary:** Programs the img2col padding register from ConvTile-like MAT tile metadata.

**Arguments:** Same as `pto.set_img2col_fmatrix`.

**Results:** None.

**Constraints & Verification:**

- Same structural constraints as `pto.set_img2col_fmatrix`.
- The current PTOAS implementation accepts `pad=null` and `pad=zero`; `pad=zero` lowers to `set_padding(0)` / `set_padding_b(0)`.

**Hardware Mapping:**

- EmitC lowers to `set_padding` / `set_padding_b` via a helper wrapper.
- Executes on the **FIX pipeline** (`PIPE_FIX`).

---

##### `pto.set_img2col_repeat` - Program Img2col Repeat State

**Summary:** Programs the img2col repeat register from ConvTile-like MAT tile metadata.

**Arguments:** Same as `pto.set_img2col_fmatrix`.

**Results:** None.

**Constraints & Verification:** Same structural constraints as `pto.set_img2col_fmatrix`.

**Hardware Mapping:**

- EmitC lowers to `set_l3d_rpt` / `set_l3d_rpt_b` via a helper wrapper.
- Executes on the **FIX pipeline** (`PIPE_FIX`).

---

##### `pto.timg2col` - Convert ConvTile to Left Tile

**Summary:** Performs the img2col transform from a ConvTile-like MAT tile into a LEFT tile suitable for cube/matmul consumption.

**Arguments:**

| Name | Type | Description |
|------|------|-------------|
| `src` | `pto.tile_buf` | Source MAT tile with `slayout=nc1hwc0` or `ndc1hwc0` and img2col metadata |
| `dst` | `pto.tile_buf` | Destination LEFT tile |
| `posM` | `i32` attribute | Optional M-position, default `0` |
| `posK` | `i32` attribute | Optional K-position, default `0` |
| `fmatrixMode` | `SetFmatrixModeAttr` | Optional mode controlling A/B bank selection and auto/manual helper behavior |

**Results:** None. Writes into `dst` following the destination-passing style (DPS) pattern.

**Constraints & Verification:**

- A5 only.
- `src.loc` must be `mat`; `dst.loc` must be `left`.
- `src.blayout` must be `row_major`.
- `src.slayout` must be `nc1hwc0` or `ndc1hwc0`.
- `dst.blayout` must be `col_major`; `dst.slayout` must be `row_major`.
- `src` and `dst` element types must match and must be one of `i8/i16/i32/f16/bf16/f32`.
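
For readers who want the source/destination pairing rules in executable form, here is a minimal C++ sketch. The `Tile` struct, the `Elem` enum (with `F64` included only as an example of an unsupported type), and `verifyTImg2col` are hypothetical; the real verifier operates on `TileBufType`, and the A5-only check is omitted.

```cpp
#include <string>

// Hypothetical stand-ins for the tile_buf properties that the rules mention.
enum class Loc { Vec, Mat, Left };
enum class BLayout { RowMajor, ColMajor };
enum class SLayout { NoneBox, RowMajor, ColMajor, NC1HWC0, NDC1HWC0 };
enum class Elem { I8, I16, I32, F16, BF16, F32, F64 /* not supported */ };

struct Tile {
  Loc loc;
  BLayout blayout;
  SLayout slayout;
  Elem elem;
};

// Element types accepted by the constraints above.
inline bool isSupportedElem(Elem e) { return e != Elem::F64; }

// Returns an empty string when the pair is legal, otherwise the first
// violated rule.
inline std::string verifyTImg2col(const Tile &src, const Tile &dst) {
  if (src.loc != Loc::Mat)
    return "src.loc must be mat";
  if (dst.loc != Loc::Left)
    return "dst.loc must be left";
  if (src.blayout != BLayout::RowMajor)
    return "src.blayout must be row_major";
  if (src.slayout != SLayout::NC1HWC0 && src.slayout != SLayout::NDC1HWC0)
    return "src.slayout must be nc1hwc0 or ndc1hwc0";
  if (dst.blayout != BLayout::ColMajor)
    return "dst.blayout must be col_major";
  if (dst.slayout != SLayout::RowMajor)
    return "dst.slayout must be row_major";
  if (src.elem != dst.elem)
    return "src and dst element types must match";
  if (!isSupportedElem(src.elem))
    return "element type must be one of i8/i16/i32/f16/bf16/f32";
  return "";
}
```

Checking the element-type match before the supported-type set means only one of the two element types has to be tested against the allowed list.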

**Hardware Mapping:**

- EmitC lowers to `TIMG2COL<dstTile, srcTile, SetFmatrixMode>(dst, src, posM, posK)`.
- The actual hardware instruction is `img2colv2_cbuf_to_ca`.
- Executes on the **MTE1 pipeline** (`PIPE_MTE1`).

**Basic Example:**

```mlir
pto.timg2col ins(%src : !pto.tile_buf<mat, 16x32xf16, slayout=nc1hwc0, pad=1, fmap_h=8, fmap_w=8, pad_l=1, pad_r=1, pad_t=1, pad_b=1, stride_h=1, stride_w=1, dilation_h=1, dilation_w=1, filter_h=3, filter_w=3, channel_size=16, dst_stride=16, repeat_stride=16, repeat_time=4, repeat_mode=0>)
outs(%dst : !pto.tile_buf<left, 16x144xf16, slayout=row_major>)
```

---

### 4.11 Sorting Operations

##### `pto.tsort32` - Sort Fixed 32-Element Blocks
35 changes: 32 additions & 3 deletions include/PTO/IR/PTOAttrs.td
@@ -598,7 +598,9 @@ def PTO_BLayoutAttr : PTO_Attr<"BLayout", "blayout"> {
def PTO_SLayout_Enum : PTO_I32Enum<"SLayout", "Secondary layout", [
I32EnumAttrCase<"NoneBox", 0, "none_box">,
I32EnumAttrCase<"RowMajor", 1, "row_major">,
I32EnumAttrCase<"ColMajor", 2, "col_major">
I32EnumAttrCase<"ColMajor", 2, "col_major">,
I32EnumAttrCase<"NC1HWC0", 3, "nc1hwc0">,
I32EnumAttrCase<"NDC1HWC0", 4, "ndc1hwc0">
]>;

def PTO_SLayoutAttr : PTO_Attr<"SLayout", "slayout"> {
@@ -629,6 +631,18 @@ def PTO_CompactModeAttr : PTO_Attr<"CompactMode", "compact_mode"> {
let assemblyFormat = "`<` params `>`";
}

def PTO_SetFmatrixMode_Enum : PTO_I32Enum<"SetFmatrixMode", "Img2col fmatrix programming mode", [
I32EnumAttrCase<"FMATRIX_A_AUTO", 0, "fmatrix_a_auto">,
I32EnumAttrCase<"FMATRIX_B_AUTO", 1, "fmatrix_b_auto">,
I32EnumAttrCase<"FMATRIX_A_MANUAL", 2, "fmatrix_a_manual">,
I32EnumAttrCase<"FMATRIX_B_MANUAL", 3, "fmatrix_b_manual">
]>;

def PTO_SetFmatrixModeAttr : PTO_Attr<"SetFmatrixMode", "set_fmatrix_mode"> {
let parameters = (ins EnumParameter<PTO_SetFmatrixMode_Enum>:$value);
let assemblyFormat = "`<` params `>`";
}

// ---------- tile_buf_config (NO b_fractal / s_fractal) ----------
def TileBufConfigAttr : AttrDef<PTO_Dialect, "TileBufConfig"> {
let mnemonic = "tile_buf_config";
@@ -637,7 +651,8 @@ def TileBufConfigAttr : AttrDef<PTO_Dialect, "TileBufConfig"> {
"SLayoutAttr":$sLayout,
"mlir::IntegerAttr":$sFractalSize, // i32
"PadValueAttr":$pad,
"CompactModeAttr":$compactMode
"CompactModeAttr":$compactMode,
"mlir::Attribute":$img2colConfig
);

let hasCustomAssemblyFormat = 1;
@@ -649,19 +664,33 @@ def TileBufConfigAttr : AttrDef<PTO_Dialect, "TileBufConfig"> {
"mlir::IntegerAttr":$sFractalSize,
"PadValueAttr":$pad,
"CompactModeAttr":$compactMode
), [{
return Base::get($_ctxt, bLayout, sLayout, sFractalSize, pad, compactMode,
mlir::Attribute());
}]>,
AttrBuilder<(ins
"BLayoutAttr":$bLayout,
"SLayoutAttr":$sLayout,
"mlir::IntegerAttr":$sFractalSize,
"PadValueAttr":$pad,
"CompactModeAttr":$compactMode,
"mlir::Attribute":$img2colConfig
)>
];

let extraClassDeclaration = [{
static TileBufConfigAttr getDefault(MLIRContext *ctx);
bool isDefault() const;
mlir::DictionaryAttr getImg2colConfigDict() const;
bool hasImg2colConfig() const;

static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
mlir::Attribute bLayout,
mlir::Attribute sLayout,
mlir::IntegerAttr sFractalSize,
mlir::Attribute pad,
mlir::Attribute compactMode);
mlir::Attribute compactMode,
mlir::Attribute img2colConfig);
}];
}

100 changes: 100 additions & 0 deletions include/PTO/IR/PTOOps.td
@@ -4225,6 +4225,106 @@ def TFillPadOp : PTO_TOp<"tfillpad", [
}];
}

def SetImg2colFmatrixOp : PTO_Op<"set_img2col_fmatrix", [
OpPipeInterface
]> {
let summary = "Program the img2col fmatrix state from tile metadata";

let arguments = (ins
PTODpsType:$config,
OptionalAttr<PTO_SetFmatrixModeAttr>:$fmatrixMode
);

let results = (outs);

let hasVerifier = 1;

let assemblyFormat = [{
$config attr-dict `:` qualified(type($config))
}];

let extraClassDeclaration = [{
::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_FIX; }
}];
}

def SetImg2colPaddingOp : PTO_Op<"set_img2col_padding", [
OpPipeInterface
]> {
let summary = "Program the img2col padding value from tile metadata";

let arguments = (ins
PTODpsType:$config,
OptionalAttr<PTO_SetFmatrixModeAttr>:$fmatrixMode
);

let results = (outs);

let hasVerifier = 1;

let assemblyFormat = [{
$config attr-dict `:` qualified(type($config))
}];

let extraClassDeclaration = [{
::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_FIX; }
}];
}

def SetImg2colRepeatOp : PTO_Op<"set_img2col_repeat", [
OpPipeInterface
]> {
let summary = "Program the img2col repeat state from tile metadata";

let arguments = (ins
PTODpsType:$config,
OptionalAttr<PTO_SetFmatrixModeAttr>:$fmatrixMode
);

let results = (outs);

let hasVerifier = 1;

let assemblyFormat = [{
$config attr-dict `:` qualified(type($config))
}];

let extraClassDeclaration = [{
::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_FIX; }
}];
}

def TImg2colOp : PTO_TOp<"timg2col", [
PTO_DpsInitOpInterface,
OpPipeInterface,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "Perform img2col from MAT source into LEFT destination";

let arguments = (ins
PTODpsType:$src,
PTODpsType:$dst,
DefaultValuedAttr<I32Attr, "0">:$posM,
DefaultValuedAttr<I32Attr, "0">:$posK,
OptionalAttr<PTO_SetFmatrixModeAttr>:$fmatrixMode
);

let results = (outs);

let hasVerifier = 1;

let assemblyFormat = [{
`ins` `(` $src `:` qualified(type($src)) `)`
`outs` `(` $dst `:` qualified(type($dst)) `)`
attr-dict
}];

let extraClassDeclaration = [{
::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_MTE1; }
::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); }
}];
}

def TFillPadExpandOp : PTO_TOp<"tfillpad_expand", [
PTO_DpsInitOpInterface,
OpPipeInterface,
4 changes: 3 additions & 1 deletion include/PTO/IR/PTOTypeDefs.td
@@ -212,10 +212,12 @@ def TileBufType : TypeDef<PTO_Dialect, "TileBuf"> {
int32_t getSFractalSizeI32() const;
mlir::Attribute getPadValueAttr() const;
mlir::Attribute getCompactModeAttr() const;
mlir::DictionaryAttr getImg2colConfigAttr() const;
bool hasImg2colConfig() const;

// If you still want "numeric enums", use these int getters (they do not depend on the enum types)
int32_t getBLayoutValueI32() const; // 0 row_major, 1 col_major
int32_t getSLayoutValueI32() const; // 0 none_box, 1 row_major, 2 col_major
int32_t getSLayoutValueI32() const; // 0 none_box, 1 row_major, 2 col_major, 3 nc1hwc0, 4 ndc1hwc0
int32_t getPadValueI32() const; // 0 null, 1 zero, 2 max, 3 min
int32_t getCompactModeI32() const; // 0 null, 1 normal, 2 row_plus_one
}];