feat: add A5 timg2col IR and lowering #703

Open
TelGome wants to merge 2 commits into hw-native-sys:main from TelGome:issue-400-timg2col-a5
@TelGome TelGome commented May 26, 2026

Fix issue: mouliangyu#400

Summary

This PR adds the missing A5 pto.timg2col frontend/IR support and wires it through the current PTO->EmitC path.

Instead of introducing a dedicated ConvTile type, this change models the img2col source tile as:

  • MAT !pto.tile_buf
  • slayout=nc1hwc0 or slayout=ndc1hwc0
  • extra img2col_config metadata carried by tile_buf_config

What Changed

  • add NC1HWC0 / NDC1HWC0 to SLayout
  • add SetFmatrixMode / SetFmatrixModeAttr
  • extend TileBufConfigAttr with optional img2col_config
  • extend TileBufType parse/print/accessors to carry img2col metadata fields:
    • fmap_h, fmap_w
    • pad_l, pad_r, pad_t, pad_b
    • stride_h, stride_w
    • dilation_h, dilation_w
    • filter_h, filter_w
    • channel_size
    • dst_stride, dst_m_position
    • repeat_stride, repeat_time, repeat_mode
    • transpose
  • add new PTO ops:
    • pto.tsetfmatrix
    • pto.tset_img2col_padding
    • pto.tset_img2col_rpt
    • pto.timg2col
  • add verifier rules for A5 img2col source/destination constraints
  • add MemoryEffects for pto.timg2col
  • update plan-memory handling for img2col setup ops
  • add C API / Python bindings for:
    • new SLayout values
    • SetFmatrixModeAttr
    • TileBufConfigAttr(..., img2col_config=...)
    • Python helper make_img2col_config(...)
  • add EmitC lowering for:
    • fmatrix setup
    • padding setup
    • repeat setup
    • TIMG2COL<...> emission
  • document the new IR surface in docs/PTO_IR_manual.md
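
As a reading aid for the img2col metadata fields listed above, the following is a minimal sketch of how such fields conventionally determine the shape of an im2col matrix. It uses the generic textbook convolution output-size formula and is an assumption for illustration only: the struct `Img2colShapeSketch` and the functions `convOutDim` and `im2colRowsCols` are hypothetical names, and the A5-specific layout details (C0 blocking, `dst_m_position`, the repeat fields, `transpose`) are not modeled.

```cpp
#include <cstdint>
#include <utility>

// Hypothetical subset of the img2col_config fields named in this PR.
struct Img2colShapeSketch {
  int64_t fmapH, fmapW;
  int64_t padL, padR, padT, padB;
  int64_t strideH, strideW;
  int64_t dilationH, dilationW;
  int64_t filterH, filterW;
  int64_t channelSize;
};

// Generic convolution output size along one axis:
//   out = floor((in + padBefore + padAfter - dilation*(filter-1) - 1) / stride) + 1
// Returns 0 when the dilated filter does not fit into the padded input.
constexpr int64_t convOutDim(int64_t in, int64_t padBefore, int64_t padAfter,
                             int64_t filter, int64_t stride, int64_t dilation) {
  const int64_t effective = dilation * (filter - 1) + 1;
  const int64_t padded = in + padBefore + padAfter;
  if (stride <= 0 || padded < effective)
    return 0;
  return (padded - effective) / stride + 1;
}

// Textbook im2col shape: M = outH * outW rows, K = filterH * filterW * C columns.
// This is an assumption about the logical shape, not the A5 memory layout.
inline std::pair<int64_t, int64_t> im2colRowsCols(const Img2colShapeSketch &c) {
  const int64_t outH =
      convOutDim(c.fmapH, c.padT, c.padB, c.filterH, c.strideH, c.dilationH);
  const int64_t outW =
      convOutDim(c.fmapW, c.padL, c.padR, c.filterW, c.strideW, c.dilationW);
  return {outH * outW, c.filterH * c.filterW * c.channelSize};
}
```

Under these assumptions, an 8x8 feature map with padding 1 on every side, a 3x3 filter, stride 1, dilation 1 and 16 channels gives a 64 x 144 logical matrix.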

Validation

Ran locally:

  • build/tools/ptoas/ptoas --pto-arch=a5 test/lit/pto/timg2col_missing_config_invalid.pto
    • expected verifier failure: missing img2col_config
  • build/tools/ptoas/ptoas --pto-arch=a5 --emit-pto-ir test/lit/pto/timg2col_frontend_lowering.pto
  • build/tools/ptoas/ptoas --pto-arch=a5 test/lit/pto/timg2col_emitc.pto
  • Python import/binding smoke check for:
    • SetFmatrixMode
    • SetFmatrixModeAttr
    • make_img2col_config

Notes

  • this PR uses tile_buf + img2col_config to represent ConvTile-like inputs
  • this PR focuses on the PTO IR / EmitC path
  • VPTO-specific support is not added here


@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces support for A5 img2col operations and configuration metadata. It adds new operations (timg2col, tsetfmatrix, tset_img2col_padding, and tset_img2col_rpt), extends TileBufConfigAttr to carry img2col_config dictionaries, and implements the corresponding lowering to EmitC. Two important issues were identified in the review: first, a logic error in the getRequired helper in PTOToEmitC.cpp where a fallback value causes required field validation to be bypassed; second, a potential overflow vulnerability in PTOAS__TIMG2COL_SET_RPT where input values are shifted and ORed without proper bit-masking.

Comment on lines +506 to +544
  auto getRequired = [&](StringRef key, int64_t fallback = 0)
      -> std::optional<int64_t> {
    if (auto value = getImg2colConfigInt(dict, key))
      return value;
    return fallback;
  };

  Img2colConfigData data;
  auto fmapH = getRequired("fmap_h");
  auto fmapW = getRequired("fmap_w");
  auto channelSize = getRequired("channel_size");
  auto repeatStride = getRequired("repeat_stride");
  auto repeatTime = getRequired("repeat_time");
  auto repeatMode = getRequired("repeat_mode");
  auto dstStride = getRequired("dst_stride");
  if (!fmapH || !fmapW || !channelSize || !repeatStride || !repeatTime ||
      !repeatMode || !dstStride)
    return std::nullopt;

  data.fmapH = *fmapH;
  data.fmapW = *fmapW;
  data.padL = getRequired("pad_l", 0).value_or(0);
  data.padR = getRequired("pad_r", 0).value_or(0);
  data.padT = getRequired("pad_t", 0).value_or(0);
  data.padB = getRequired("pad_b", 0).value_or(0);
  data.strideH = getRequired("stride_h", 1).value_or(1);
  data.strideW = getRequired("stride_w", 1).value_or(1);
  data.dilationH = getRequired("dilation_h", 1).value_or(1);
  data.dilationW = getRequired("dilation_w", 1).value_or(1);
  data.filterH = getRequired("filter_h", 1).value_or(1);
  data.filterW = getRequired("filter_w", 1).value_or(1);
  data.channelSize = *channelSize;
  data.dstStride = *dstStride;
  data.dstMPosition = getRequired("dst_m_position", 0).value_or(0);
  data.repeatStride = *repeatStride;
  data.repeatTime = *repeatTime;
  data.repeatMode = *repeatMode;
  data.transpose = getRequired("transpose", 0).value_or(0);
  return data;

high

The getRequired helper function always returns an engaged std::optional because it falls back to fallback (defaulting to 0) wrapped in std::optional when the key is missing. This completely bypasses the required field validation check if (!fmapH || ...) on lines 521-523, which will never evaluate to true. We should split this into a strict getRequired helper (which returns std::optional without a fallback) and a getOptional helper (which returns int64_t with a fallback).

  auto getRequired = [&](StringRef key) -> std::optional<int64_t> {
    return getImg2colConfigInt(dict, key);
  };
  auto getOptional = [&](StringRef key, int64_t fallback) -> int64_t {
    return getImg2colConfigInt(dict, key).value_or(fallback);
  };

  auto fmapH = getRequired("fmap_h");
  auto fmapW = getRequired("fmap_w");
  auto channelSize = getRequired("channel_size");
  auto repeatStride = getRequired("repeat_stride");
  auto repeatTime = getRequired("repeat_time");
  auto repeatMode = getRequired("repeat_mode");
  auto dstStride = getRequired("dst_stride");
  if (!fmapH || !fmapW || !channelSize || !repeatStride || !repeatTime ||
      !repeatMode || !dstStride)
    return std::nullopt;

  Img2colConfigData data;
  data.fmapH = *fmapH;
  data.fmapW = *fmapW;
  data.padL = getOptional("pad_l", 0);
  data.padR = getOptional("pad_r", 0);
  data.padT = getOptional("pad_t", 0);
  data.padB = getOptional("pad_b", 0);
  data.strideH = getOptional("stride_h", 1);
  data.strideW = getOptional("stride_w", 1);
  data.dilationH = getOptional("dilation_h", 1);
  data.dilationW = getOptional("dilation_w", 1);
  data.filterH = getOptional("filter_h", 1);
  data.filterW = getOptional("filter_w", 1);
  data.channelSize = *channelSize;
  data.dstStride = *dstStride;
  data.dstMPosition = getOptional("dst_m_position", 0);
  data.repeatStride = *repeatStride;
  data.repeatTime = *repeatTime;
  data.repeatMode = *repeatMode;
  data.transpose = getOptional("transpose", 0);
  return data;
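
The failure mode described in this comment can be reproduced outside the compiler. The sketch below is illustrative only: `Config`, `lookup` and `getWithFallback` are hypothetical stand-ins (a `std::map` replaces the attribute dictionary, and `lookup` plays the role of `getImg2colConfigInt`). It shows that returning the fallback from a function whose return type is `std::optional<int64_t>` yields an engaged optional, so a later `!value` check can never fire, while the split helpers behave as intended.

```cpp
#include <cstdint>
#include <map>
#include <optional>
#include <string>

// Hypothetical stand-in for the img2col_config dictionary.
using Config = std::map<std::string, int64_t>;

// Stand-in for getImg2colConfigInt: empty optional when the key is absent.
inline std::optional<int64_t> lookup(const Config &cfg, const std::string &key) {
  auto it = cfg.find(key);
  if (it == cfg.end())
    return std::nullopt;
  return it->second;
}

// Shape of the original helper: the int64_t fallback is implicitly converted
// to an engaged std::optional, so the result always has a value.
inline std::optional<int64_t> getWithFallback(const Config &cfg,
                                              const std::string &key,
                                              int64_t fallback = 0) {
  if (auto value = lookup(cfg, key))
    return value;
  return fallback;
}

// Strict variant: a missing key stays observable as std::nullopt.
inline std::optional<int64_t> getRequired(const Config &cfg,
                                          const std::string &key) {
  return lookup(cfg, key);
}

// Defaulting variant: returns a plain integer, never an optional.
inline int64_t getOptional(const Config &cfg, const std::string &key,
                           int64_t fallback) {
  return lookup(cfg, key).value_or(fallback);
}
```

With a config that only contains `fmap_w`, `getWithFallback(cfg, "fmap_h")` is engaged and holds 0, whereas `getRequired(cfg, "fmap_h")` is empty and lets the caller reject the input.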

Comment on lines +12924 to +12943
template <bool UseB = false>
static AICORE inline void PTOAS__TIMG2COL_SET_RPT(
    int32_t repeatStride, int32_t repeatTime, int32_t repeatMode,
    int32_t dstStride, int32_t dstMposition) {
  uint64_t rptConfig = 0;
  constexpr uint32_t repeatTimeShiftBit = 16;
  constexpr uint32_t repeatModeShiftBit = 24;
  constexpr uint32_t dstStrideShiftBit = 32;
  constexpr uint32_t dstMpositionShiftBit = 48;
  rptConfig |= uint64_t(repeatStride);
  rptConfig |= uint64_t(repeatTime) << repeatTimeShiftBit;
  rptConfig |= uint64_t(repeatMode) << repeatModeShiftBit;
  rptConfig |= uint64_t(dstStride) << dstStrideShiftBit;
  rptConfig |= uint64_t(dstMposition) << dstMpositionShiftBit;
  if constexpr (UseB) {
    set_l3d_rpt_b(rptConfig);
  } else {
    set_l3d_rpt(rptConfig);
  }
}

medium

In PTOAS__TIMG2COL_SET_RPT, the values are shifted and ORed into rptConfig without any masking. If any of the input values (like repeatStride or repeatTime) exceed their expected bit-widths, they will overflow and clobber adjacent fields. We should apply explicit bit-masks (e.g., & 0xFFFF or & 0xFF) just like in PTOAS__TIMG2COL_SET_FMATRIX to ensure robustness.

Suggested change
template <bool UseB = false>
static AICORE inline void PTOAS__TIMG2COL_SET_RPT(
    int32_t repeatStride, int32_t repeatTime, int32_t repeatMode,
    int32_t dstStride, int32_t dstMposition) {
  uint64_t rptConfig = 0;
  constexpr uint32_t repeatTimeShiftBit = 16;
  constexpr uint32_t repeatModeShiftBit = 24;
  constexpr uint32_t dstStrideShiftBit = 32;
  constexpr uint32_t dstMpositionShiftBit = 48;
  rptConfig |= uint64_t(repeatStride & 0xFFFF);
  rptConfig |= uint64_t(repeatTime & 0xFF) << repeatTimeShiftBit;
  rptConfig |= uint64_t(repeatMode & 0xFF) << repeatModeShiftBit;
  rptConfig |= uint64_t(dstStride & 0xFFFF) << dstStrideShiftBit;
  rptConfig |= uint64_t(dstMposition & 0xFFFF) << dstMpositionShiftBit;
  if constexpr (UseB) {
    set_l3d_rpt_b(rptConfig);
  } else {
    set_l3d_rpt(rptConfig);
  }
}
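
To make the clobbering concrete, here is a host-side sketch of the same packing without the device intrinsics. The function names `packRptUnmasked` and `packRptMasked` are hypothetical; the shift amounts come from the reviewed helper and the field widths are taken from the masks in the suggestion above, so they are assumptions about the register layout rather than a confirmed specification. Note that converting a negative `int32_t` to `uint64_t` sets every upper bit, which is the worst case for the unmasked version.

```cpp
#include <cstdint>

// Unmasked packing, mirroring the reviewed helper: any value wider than its
// slot spills into the neighbouring fields.
constexpr uint64_t packRptUnmasked(int32_t repeatStride, int32_t repeatTime,
                                   int32_t repeatMode, int32_t dstStride,
                                   int32_t dstMposition) {
  uint64_t cfg = 0;
  cfg |= uint64_t(repeatStride);
  cfg |= uint64_t(repeatTime) << 16;
  cfg |= uint64_t(repeatMode) << 24;
  cfg |= uint64_t(dstStride) << 32;
  cfg |= uint64_t(dstMposition) << 48;
  return cfg;
}

// Masked packing, mirroring the suggested change: each value is truncated to
// its assumed slot width before being shifted into place.
constexpr uint64_t packRptMasked(int32_t repeatStride, int32_t repeatTime,
                                 int32_t repeatMode, int32_t dstStride,
                                 int32_t dstMposition) {
  uint64_t cfg = 0;
  cfg |= uint64_t(repeatStride & 0xFFFF);
  cfg |= uint64_t(repeatTime & 0xFF) << 16;
  cfg |= uint64_t(repeatMode & 0xFF) << 24;
  cfg |= uint64_t(dstStride & 0xFFFF) << 32;
  cfg |= uint64_t(dstMposition & 0xFFFF) << 48;
  return cfg;
}

// In-range inputs pack identically in both versions.
static_assert(packRptMasked(1, 2, 3, 4, 5) == packRptUnmasked(1, 2, 3, 4, 5),
              "masking must not change in-range values");
```

Masking keeps an oversized field from corrupting its neighbours, but the value itself is still silently truncated, so it complements rather than replaces a range check earlier in the pipeline.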


reedhecre commented May 26, 2026

Codex Review

This comment is updated automatically by the review bot.

  • PR: #703 feat: add A5 timg2col IR and lowering
  • Author: TelGome
  • Base/Head: main / issue-400-timg2col-a5
  • Head SHA: 101f797c1f5a
  • Trigger: new commits pushed to the PR
  • Generated At: 2026-05-27T08:30:05Z
  • Previous Head SHA: 73acf4c5fd54
  • Status: completed

Summary

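The A5 timg2col lowering in PR #703 still has three correctness issues: the default path falls into manual mode, several required img2col metadata fields are dropped in EmitC, and parameters are not range-checked before being packed into register fields.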

Findings

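  1. P1 When fmatrixMode is omitted, timg2col defaults to manual mode and performs no register initialization lib/PTO/Transforms/PTOToEmitC.cpp:554

When fmatrixMode is omitted, getImg2colMode() returns FMATRIX_A_MANUAL directly. The EmitC lowering of pto.timg2col then emits only TIMG2COL<..., FMATRIX_A_MANUAL>(dst, src, posM, posK) and inserts no PTOAS__TIMG2COL_SET_FMATRIX / SET_PADDING / SET_RPT calls. As a result, the most common default form depends on whatever img2col state was left behind earlier; unless the caller always inserts the three setup ops by hand beforehand, this produces wrong results.

  2. P1 Several required img2col metadata fields are dropped outright in the EmitC lowering lib/PTO/Transforms/PTOToEmitC.cpp:8518

The current lowering consumes only fmap_*, pad_*, repeat_*, dst_stride, and dst_m_position. stride_h/stride_w, dilation_h/dilation_w, filter_h/filter_w, channel_size, and transpose are accepted by the parser and verifier, but none of them reaches the generated C++. In addition, getEmitCTileTypeString() serializes the source tile as a plain Tile<...> that carries none of this img2col metadata, so there is nowhere to recover it even after switching to AUTO mode. Consequently, two IR inputs that differ only in these fields lower to identical code, and the generated kernel cannot implement the requested img2col configuration.

  3. P2 The verifier does not range-check parameters that are packed into 8/16-bit fields lib/PTO/IR/PTO.cpp:1584

verifyImg2colConfigLike() currently checks only for positive or non-negative values. The EmitC helpers later pack fmap_h/fmap_w into 16-bit slots, pad_*, repeat_time, and repeat_mode into 8-bit slots, and dst_stride/dst_m_position into 16-bit slots; pto.timg2col also narrows posM/posK directly to uint16_t. IR values that exceed these widths raise no error. They are silently truncated, or even overwrite adjacent bitfields, so the compiled hardware configuration no longer matches the IR semantics.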
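
One way to close the gap described in the third finding is a width check before any packing happens. The following is a minimal standalone sketch, not the verifier's actual code: `fitsUnsigned`, `FieldCheck` and `firstOutOfRange` are hypothetical names, and the bit widths passed by the caller are the ones quoted in the finding above. In the real verifier the failure would presumably be reported through the op's diagnostic mechanism instead of a returned string.

```cpp
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// True when `value` is representable as an unsigned field of `bits` bits.
constexpr bool fitsUnsigned(int64_t value, unsigned bits) {
  if (value < 0)
    return false;
  if (bits >= 63)
    return true; // every non-negative int64_t fits
  return value < (int64_t(1) << bits);
}

// One entry per config field: its name, its value, and its slot width.
struct FieldCheck {
  const char *name;
  int64_t value;
  unsigned bits;
};

// Returns the name of the first field that does not fit its slot, or
// std::nullopt when every field is in range.
inline std::optional<std::string>
firstOutOfRange(const std::vector<FieldCheck> &checks) {
  for (const FieldCheck &c : checks)
    if (!fitsUnsigned(c.value, c.bits))
      return std::string(c.name);
  return std::nullopt;
}
```

A caller could, for example, pass `{"fmap_h", fmapH, 16}`, `{"pad_l", padL, 8}` and `{"dst_stride", dstStride, 16}` and reject the op when a name comes back.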

Comment thread docs/PTO_IR_manual.md Outdated
@TelGome TelGome requested a review from Zhendong404 May 27, 2026 08:47

3 participants