feat: add A5 timg2col IR and lowering#703
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces support for A5 img2col operations and configuration metadata. It adds new operations (timg2col, tsetfmatrix, tset_img2col_padding, and tset_img2col_rpt), extends TileBufConfigAttr to carry img2col_config dictionaries, and implements the corresponding lowering to EmitC. Two important issues were identified in the review: first, a logic error in the getRequired helper in PTOToEmitC.cpp where a fallback value causes required field validation to be bypassed; second, a potential overflow vulnerability in PTOAS__TIMG2COL_SET_RPT where input values are shifted and ORed without proper bit-masking.
| auto getRequired = [&](StringRef key, int64_t fallback = 0) | ||
| -> std::optional<int64_t> { | ||
| if (auto value = getImg2colConfigInt(dict, key)) | ||
| return value; | ||
| return fallback; | ||
| }; | ||
|
|
||
| Img2colConfigData data; | ||
| auto fmapH = getRequired("fmap_h"); | ||
| auto fmapW = getRequired("fmap_w"); | ||
| auto channelSize = getRequired("channel_size"); | ||
| auto repeatStride = getRequired("repeat_stride"); | ||
| auto repeatTime = getRequired("repeat_time"); | ||
| auto repeatMode = getRequired("repeat_mode"); | ||
| auto dstStride = getRequired("dst_stride"); | ||
| if (!fmapH || !fmapW || !channelSize || !repeatStride || !repeatTime || | ||
| !repeatMode || !dstStride) | ||
| return std::nullopt; | ||
|
|
||
| data.fmapH = *fmapH; | ||
| data.fmapW = *fmapW; | ||
| data.padL = getRequired("pad_l", 0).value_or(0); | ||
| data.padR = getRequired("pad_r", 0).value_or(0); | ||
| data.padT = getRequired("pad_t", 0).value_or(0); | ||
| data.padB = getRequired("pad_b", 0).value_or(0); | ||
| data.strideH = getRequired("stride_h", 1).value_or(1); | ||
| data.strideW = getRequired("stride_w", 1).value_or(1); | ||
| data.dilationH = getRequired("dilation_h", 1).value_or(1); | ||
| data.dilationW = getRequired("dilation_w", 1).value_or(1); | ||
| data.filterH = getRequired("filter_h", 1).value_or(1); | ||
| data.filterW = getRequired("filter_w", 1).value_or(1); | ||
| data.channelSize = *channelSize; | ||
| data.dstStride = *dstStride; | ||
| data.dstMPosition = getRequired("dst_m_position", 0).value_or(0); | ||
| data.repeatStride = *repeatStride; | ||
| data.repeatTime = *repeatTime; | ||
| data.repeatMode = *repeatMode; | ||
| data.transpose = getRequired("transpose", 0).value_or(0); | ||
| return data; |
There was a problem hiding this comment.
The getRequired helper function always returns an engaged std::optional because it falls back to fallback (defaulting to 0) wrapped in std::optional when the key is missing. This completely bypasses the required field validation check if (!fmapH || ...) on lines 521-523, which will never evaluate to true. We should split this into a strict getRequired helper (which returns std::optional without a fallback) and a getOptional helper (which returns int64_t with a fallback).
auto getRequired = [&](StringRef key) -> std::optional<int64_t> {
return getImg2colConfigInt(dict, key);
};
auto getOptional = [&](StringRef key, int64_t fallback) -> int64_t {
return getImg2colConfigInt(dict, key).value_or(fallback);
};
auto fmapH = getRequired("fmap_h");
auto fmapW = getRequired("fmap_w");
auto channelSize = getRequired("channel_size");
auto repeatStride = getRequired("repeat_stride");
auto repeatTime = getRequired("repeat_time");
auto repeatMode = getRequired("repeat_mode");
auto dstStride = getRequired("dst_stride");
if (!fmapH || !fmapW || !channelSize || !repeatStride || !repeatTime ||
!repeatMode || !dstStride)
return std::nullopt;
Img2colConfigData data;
data.fmapH = *fmapH;
data.fmapW = *fmapW;
data.padL = getOptional("pad_l", 0);
data.padR = getOptional("pad_r", 0);
data.padT = getOptional("pad_t", 0);
data.padB = getOptional("pad_b", 0);
data.strideH = getOptional("stride_h", 1);
data.strideW = getOptional("stride_w", 1);
data.dilationH = getOptional("dilation_h", 1);
data.dilationW = getOptional("dilation_w", 1);
data.filterH = getOptional("filter_h", 1);
data.filterW = getOptional("filter_w", 1);
data.channelSize = *channelSize;
data.dstStride = *dstStride;
data.dstMPosition = getOptional("dst_m_position", 0);
data.repeatStride = *repeatStride;
data.repeatTime = *repeatTime;
data.repeatMode = *repeatMode;
data.transpose = getOptional("transpose", 0);
return data;| template <bool UseB = false> | ||
| static AICORE inline void PTOAS__TIMG2COL_SET_RPT( | ||
| int32_t repeatStride, int32_t repeatTime, int32_t repeatMode, | ||
| int32_t dstStride, int32_t dstMposition) { | ||
| uint64_t rptConfig = 0; | ||
| constexpr uint32_t repeatTimeShiftBit = 16; | ||
| constexpr uint32_t repeatModeShiftBit = 24; | ||
| constexpr uint32_t dstStrideShiftBit = 32; | ||
| constexpr uint32_t dstMpositionShiftBit = 48; | ||
| rptConfig |= uint64_t(repeatStride); | ||
| rptConfig |= uint64_t(repeatTime) << repeatTimeShiftBit; | ||
| rptConfig |= uint64_t(repeatMode) << repeatModeShiftBit; | ||
| rptConfig |= uint64_t(dstStride) << dstStrideShiftBit; | ||
| rptConfig |= uint64_t(dstMposition) << dstMpositionShiftBit; | ||
| if constexpr (UseB) { | ||
| set_l3d_rpt_b(rptConfig); | ||
| } else { | ||
| set_l3d_rpt(rptConfig); | ||
| } | ||
| } |
There was a problem hiding this comment.
In PTOAS__TIMG2COL_SET_RPT, the values are shifted and ORed into rptConfig without any masking. If any of the input values (like repeatStride or repeatTime) exceed their expected bit-widths, they will overflow and clobber adjacent fields. We should apply explicit bit-masks (e.g., & 0xFFFF or & 0xFF) just like in PTOAS__TIMG2COL_SET_FMATRIX to ensure robustness.
| template <bool UseB = false> | |
| static AICORE inline void PTOAS__TIMG2COL_SET_RPT( | |
| int32_t repeatStride, int32_t repeatTime, int32_t repeatMode, | |
| int32_t dstStride, int32_t dstMposition) { | |
| uint64_t rptConfig = 0; | |
| constexpr uint32_t repeatTimeShiftBit = 16; | |
| constexpr uint32_t repeatModeShiftBit = 24; | |
| constexpr uint32_t dstStrideShiftBit = 32; | |
| constexpr uint32_t dstMpositionShiftBit = 48; | |
| rptConfig |= uint64_t(repeatStride); | |
| rptConfig |= uint64_t(repeatTime) << repeatTimeShiftBit; | |
| rptConfig |= uint64_t(repeatMode) << repeatModeShiftBit; | |
| rptConfig |= uint64_t(dstStride) << dstStrideShiftBit; | |
| rptConfig |= uint64_t(dstMposition) << dstMpositionShiftBit; | |
| if constexpr (UseB) { | |
| set_l3d_rpt_b(rptConfig); | |
| } else { | |
| set_l3d_rpt(rptConfig); | |
| } | |
| } | |
| template <bool UseB = false> | |
| static AICORE inline void PTOAS__TIMG2COL_SET_RPT( | |
| int32_t repeatStride, int32_t repeatTime, int32_t repeatMode, | |
| int32_t dstStride, int32_t dstMposition) { | |
| uint64_t rptConfig = 0; | |
| constexpr uint32_t repeatTimeShiftBit = 16; | |
| constexpr uint32_t repeatModeShiftBit = 24; | |
| constexpr uint32_t dstStrideShiftBit = 32; | |
| constexpr uint32_t dstMpositionShiftBit = 48; | |
| rptConfig |= uint64_t(repeatStride & 0xFFFF); | |
| rptConfig |= uint64_t(repeatTime & 0xFF) << repeatTimeShiftBit; | |
| rptConfig |= uint64_t(repeatMode & 0xFF) << repeatModeShiftBit; | |
| rptConfig |= uint64_t(dstStride & 0xFFFF) << dstStrideShiftBit; | |
| rptConfig |= uint64_t(dstMposition & 0xFFFF) << dstMpositionShiftBit; | |
| if constexpr (UseB) { | |
| set_l3d_rpt_b(rptConfig); | |
| } else { | |
| set_l3d_rpt(rptConfig); | |
| } | |
| } |
Codex Review该评论由 review 机器人自动更新。
SummaryPR #703 的 A5 timg2col lowering 还存在 3 个 correctness 问题:默认路径落到 manual 模式、多个必要 img2col 元数据在 EmitC 中被丢弃、以及参数在打包到寄存器字段前缺少范围校验。 Findings
当 fmatrixMode 缺省时,getImg2colMode() 直接返回 FMATRIX_A_MANUAL。随后 pto.timg2col 的 EmitC lowering 只会生成 TIMG2COL<..., FMATRIX_A_MANUAL>(dst, src, posM, posK),不会插入任何 PTOAS__TIMG2COL_SET_FMATRIX / SET_PADDING / SET_RPT 调用。这样一来,最常见的默认写法会依赖先前遗留的 img2col 状态;除非调用方总是手工在前面插入 3 个 setup op,否则这里会产生错误结果。
当前 lowering 只消费了 fmap_、pad_、repeat_*、dst_stride 和 dst_m_position。stride_h/stride_w、dilation_h/dilation_w、filter_h/filter_w、channel_size、transpose 虽然被 parser/verifier 接受,但没有任何一处进入生成的 C++。同时,getEmitCTileTypeString() 也只是把源 tile 序列化成普通的 Tile<...>,没有携带这些 img2col 元数据,所以即使切到 AUTO 模式也无处恢复它们。结果是:两个只在这些字段上不同的 IR,会 lower 成完全相同的代码,生成的 kernel 不可能实现请求的 img2col 配置。
verifyImg2colConfigLike() 目前只检查正数/非负数。后续 EmitC helper 会把 fmap_h/fmap_w 压进 16-bit 槽,把 pad_*、repeat_time、repeat_mode 压进 8-bit 槽,把 dst_stride/dst_m_position 压进 16-bit 槽;pto.timg2col 还会把 posM/posK 直接收窄成 uint16_t。超出这些位宽的 IR 值不会报错,只会被静默截断,甚至覆盖相邻 bitfield,最终让编译出来的硬件配置和 IR 语义不一致。 |
Fix issue: mouliangyu#400
Summary
This PR adds the missing A5
pto.timg2colfrontend/IR support and wires it through the current PTO->EmitC path.Instead of introducing a dedicated
ConvTiletype, this change models the img2col source tile as:!pto.tile_bufslayout=nc1hwc0orslayout=ndc1hwc0img2col_configmetadata carried bytile_buf_configWhat Changed
NC1HWC0/NDC1HWC0toSLayoutSetFmatrixMode/SetFmatrixModeAttrTileBufConfigAttrwith optionalimg2col_configTileBufTypeparse/print/accessors to carry img2col metadata fields:fmap_h,fmap_wpad_l,pad_r,pad_t,pad_bstride_h,stride_wdilation_h,dilation_wfilter_h,filter_wchannel_sizedst_stride,dst_m_positionrepeat_stride,repeat_time,repeat_modetransposepto.tsetfmatrixpto.tset_img2col_paddingpto.tset_img2col_rptpto.timg2colMemoryEffectsforpto.timg2colSLayoutvaluesSetFmatrixModeAttrTileBufConfigAttr(..., img2col_config=...)make_img2col_config(...)TIMG2COL<...>emissiondocs/PTO_IR_manual.mdValidation
Ran locally:
build/tools/ptoas/ptoas --pto-arch=a5 test/lit/pto/timg2col_missing_config_invalid.ptoimg2col_configbuild/tools/ptoas/ptoas --pto-arch=a5 --emit-pto-ir test/lit/pto/timg2col_frontend_lowering.ptobuild/tools/ptoas/ptoas --pto-arch=a5 test/lit/pto/timg2col_emitc.ptoSetFmatrixModeSetFmatrixModeAttrmake_img2col_configNotes
tile_buf + img2col_configto represent ConvTile-like inputs