For the following IREE input, BOO runs in 551us versus 1031us for Fusilli (roughly 1.9x faster).
BOO:
// BOO-generated IR. #map is the 4-D identity map. #map1..#map3 describe a 7-D
// contraction (4 parallel dims d0..d3, 3 reduction dims d4..d6) where the
// first operand is read at (d0, d1 + d5, d2 + d6, d4) -- i.e. a sliding
// window along d1/d2, the access pattern of a 2-D convolution.
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d4)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
module {
// Async entry point (coarse-fences ABI): waits on %arg3 before reading inputs,
// signals %arg4 once the result tensor is ready.
// NOTE(review): %arg1 is accepted but never imported or used in this body --
// presumably a placeholder kept for ABI/signature parity; verify against caller.
util.func public @fused_op_layout_cus_ce99c970a1fbf4682883ac5fef7d1b047d72a5c2_16x1x256x2048xbfloat16_16x1x512x2048xbfloat16_2048x1x2x2048xbfloat16$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.fence, %arg4: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
%cst = arith.constant 0.000000e+00 : f32
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : bf16
// Zero-initialized destination for the strided insert below (513 = 2*256 + 1
// along dim 2, leaving zeros interleaved between input rows).
%cst_1 = arith.constant dense<0.000000e+00> : tensor<16x1x513x2048xbf16>
// Input activation (%arg0) and filter (%arg2); both gated on the wait fence.
%0 = hal.tensor.import wait(%arg3) => %arg0 : !hal.buffer_view -> tensor<16x1x256x2048xbf16>
%1 = hal.tensor.import wait(%arg3) => %arg2 : !hal.buffer_view -> tensor<2048x1x2x2048xbf16>
%2 = tensor.empty() : tensor<2048x1x2x2048xbf16>
%3 = linalg.fill ins(%cst_0 : bf16) outs(%2 : tensor<2048x1x2x2048xbf16>) -> tensor<2048x1x2x2048xbf16>
// Flip the filter along dim 2 (the size-2 window): each output element is
// gathered from index (1 - d2) via tensor.extract, reversing the window.
%4 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1 : tensor<2048x1x2x2048xbf16>) outs(%3 : tensor<2048x1x2x2048xbf16>) {
^bb0(%in: bf16, %out: bf16):
%12 = linalg.index 0 : index
%13 = linalg.index 2 : index
%14 = linalg.index 3 : index
%15 = arith.subi %c1, %13 : index
%extracted = tensor.extract %1[%12, %c0, %15, %14] : tensor<2048x1x2x2048xbf16>
linalg.yield %extracted : bf16
} -> tensor<2048x1x2x2048xbf16>
// Scatter the 256-row input into the 513-row zero tensor at offset 1 with
// stride 2 along dim 2 (input dilation). Combined with the flipped filter
// above, this looks like a transposed-convolution lowering -- TODO confirm.
%inserted_slice = tensor.insert_slice %0 into %cst_1[0, 0, 1, 0] [16, 1, 256, 2048] [1, 1, 2, 1] : tensor<16x1x256x2048xbf16> into tensor<16x1x513x2048xbf16>
%5 = tensor.empty() : tensor<16x1x512x2048xf32>
%6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<16x1x512x2048xf32>) -> tensor<16x1x512x2048xf32>
// Convolution contraction: bf16 operands extended to f32, multiply-accumulate
// into an f32 accumulator (reduction over d4/d5/d6).
%7 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%inserted_slice, %4 : tensor<16x1x513x2048xbf16>, tensor<2048x1x2x2048xbf16>) outs(%6 : tensor<16x1x512x2048xf32>) {
^bb0(%in: bf16, %in_2: bf16, %out: f32):
%12 = arith.extf %in : bf16 to f32
%13 = arith.extf %in_2 : bf16 to f32
%14 = arith.mulf %12, %13 : f32
%15 = arith.addf %out, %14 : f32
linalg.yield %15 : f32
} -> tensor<16x1x512x2048xf32>
%8 = tensor.empty() : tensor<16x1x512x2048xbf16>
// Truncate the f32 accumulator back to bf16 for the external result.
%9 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%7 : tensor<16x1x512x2048xf32>) outs(%8 : tensor<16x1x512x2048xbf16>) {
^bb0(%in: f32, %out: bf16):
%12 = arith.truncf %in : f32 to bf16
linalg.yield %12 : bf16
} -> tensor<16x1x512x2048xbf16>
// Signal completion on %arg4, then export the result as a buffer view.
%10 = hal.tensor.barrier join(%9 : tensor<16x1x512x2048xbf16>) => %arg4 : !hal.fence
%11 = hal.tensor.export %10 : tensor<16x1x512x2048xbf16> -> !hal.buffer_view
util.return %11 : !hal.buffer_view
}
// Synchronous wrapper: calls the async entry with a null wait fence, then
// blocks (infinite timeout, -1) until the signal fence fires.
// NOTE(review): %status from the fence await is ignored.
util.func public @fused_op_layout_cus_ce99c970a1fbf4682883ac5fef7d1b047d72a5c2_16x1x256x2048xbfloat16_16x1x512x2048xbfloat16_2048x1x2x2048xbfloat16(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%0 = util.null : !hal.fence
%c-1_i32 = arith.constant -1 : i32
%c0 = arith.constant 0 : index
%device_0 = hal.devices.get %c0 : !hal.device
%fence = hal.fence.create device(%device_0 : !hal.device) flags("None") : !hal.fence
%1 = util.call @fused_op_layout_cus_ce99c970a1fbf4682883ac5fef7d1b047d72a5c2_16x1x256x2048xbfloat16_16x1x512x2048xbfloat16_2048x1x2x2048xbfloat16$async(%arg0, %arg1, %arg2, %0, %fence) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.fence, !hal.fence) -> !hal.buffer_view
%status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) flags("None") : i32
util.return %1 : !hal.buffer_view
}
}
Fusilli:
// Fusilli-generated IR for the same computation. Here the contraction slides
// the window along d2/d3 (#map reads (d0, d4, d2 + d5, d3 + d6)) with the
// reduction (channel) dim at position 1 of the filter map, so explicit
// linalg.transpose ops are needed to move data into that layout.
// NOTE(review): unlike the BOO version, this module pays for three transposes
// (both inputs and the output) around the contraction -- a likely contributor
// to the 1031us vs 551us gap.
#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d5, d6)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module @module {
// Async entry point (coarse-fences ABI): %arg1/%arg2 are inputs, %arg0 is the
// caller-provided output buffer (written via hal.tensor.alias below); waits on
// %arg3, signals %arg4.
util.func public @main$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.fence, %arg4: !hal.fence) attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
%cst = arith.constant 0.000000e+00 : bf16
%cst_0 = arith.constant 0.000000e+00 : f32
// Input activation (%arg1) and filter (%arg2); both gated on the wait fence.
%0 = hal.tensor.import wait(%arg3) => %arg1 : !hal.buffer_view -> tensor<16x256x1x2048xbf16>
%1 = hal.tensor.import wait(%arg3) => %arg2 : !hal.buffer_view -> tensor<2048x2x1x2048xbf16>
%2 = tensor.empty() : tensor<16x2048x1x256xbf16>
// Swap dims 1 and 3 of both operands (permutation [0, 3, 2, 1]) to match the
// contraction's expected layout: 16x256x1x2048 -> 16x2048x1x256 and
// 2048x2x1x2048 -> 2048x2048x1x2.
%transposed = linalg.transpose ins(%0 : tensor<16x256x1x2048xbf16>) outs(%2 : tensor<16x2048x1x256xbf16>) permutation = [0, 3, 2, 1]
%3 = tensor.empty() : tensor<2048x2048x1x2xbf16>
%transposed_1 = linalg.transpose ins(%1 : tensor<2048x2x1x2048xbf16>) outs(%3 : tensor<2048x2048x1x2xbf16>) permutation = [0, 3, 2, 1]
%4 = tensor.empty() : tensor<16x2048x1x512xbf16>
%5 = tensor.empty() : tensor<16x2048x1x513xbf16>
// Zero-fill a 513-wide destination, then scatter the 256 input columns at
// offset 1 with stride 2 along the last dim (input dilation) -- the same
// transposed-convolution-style padding as the BOO version, on the last axis.
%6 = linalg.fill ins(%cst : bf16) outs(%5 : tensor<16x2048x1x513xbf16>) -> tensor<16x2048x1x513xbf16>
%inserted_slice = tensor.insert_slice %transposed into %6[0, 0, 0, 1] [16, 2048, 1, 256] [1, 1, 1, 2] : tensor<16x2048x1x256xbf16> into tensor<16x2048x1x513xbf16>
%7 = tensor.empty() : tensor<16x2048x1x512xf32>
%8 = linalg.fill ins(%cst_0 : f32) outs(%7 : tensor<16x2048x1x512xf32>) -> tensor<16x2048x1x512xf32>
// Convolution contraction: bf16 operands extended to f32, multiply-accumulate
// into an f32 accumulator (reduction over d4/d5/d6).
%9 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%inserted_slice, %transposed_1 : tensor<16x2048x1x513xbf16>, tensor<2048x2048x1x2xbf16>) outs(%8 : tensor<16x2048x1x512xf32>) {
^bb0(%in: bf16, %in_3: bf16, %out: f32):
%14 = arith.extf %in : bf16 to f32
%15 = arith.extf %in_3 : bf16 to f32
%16 = arith.mulf %14, %15 : f32
%17 = arith.addf %16, %out : f32
linalg.yield %17 : f32
} -> tensor<16x2048x1x512xf32>
// Truncate the f32 accumulator back to bf16.
%10 = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%9 : tensor<16x2048x1x512xf32>) outs(%4 : tensor<16x2048x1x512xbf16>) {
^bb0(%in: f32, %out: bf16):
%14 = arith.truncf %in : f32 to bf16
linalg.yield %14 : bf16
} -> tensor<16x2048x1x512xbf16>
%11 = tensor.empty() : tensor<16x512x1x2048xbf16>
// Transpose the result back to the caller's layout, write it into the
// caller-supplied output buffer (%arg0), and signal the fence.
%transposed_2 = linalg.transpose ins(%10 : tensor<16x2048x1x512xbf16>) outs(%11 : tensor<16x512x1x2048xbf16>) permutation = [0, 3, 2, 1]
%12 = hal.tensor.alias wait(%arg3) => %transposed_2 : tensor<16x512x1x2048xbf16> to %arg0 : !hal.buffer_view
%13 = hal.tensor.barrier join(%12 : tensor<16x512x1x2048xbf16>) => %arg4 : !hal.fence
util.return
}
// Synchronous wrapper: calls the async entry with a null wait fence, then
// blocks (infinite timeout, -1) until the signal fence fires.
// NOTE(review): %status from the fence await is ignored.
util.func public @main(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) attributes {iree.abi.stub} {
%0 = util.null : !hal.fence
%c-1_i32 = arith.constant -1 : i32
%c0 = arith.constant 0 : index
%device_0 = hal.devices.get %c0 : !hal.device
%fence = hal.fence.create device(%device_0 : !hal.device) flags("None") : !hal.fence
util.call @main$async(%arg0, %arg1, %arg2, %0, %fence) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.fence, !hal.fence) -> ()
%status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) flags("None") : i32
util.return
}
}
And the IRs after dispatch generation: https://gist.github.com/yzhang93/642c4ce5443e5e3b49f7f73409811246