Non-unit stride input backward convolution: different input IR and dispatches between BOO and Fusilli #101

@yzhang93

Description

The performance of BOO is 551us vs 1031us for Fusilli on the IREE inputs below. The two frontends hand IREE structurally different IR for the same non-unit-stride input-backward convolution: BOO works on the imported layout directly, flipping the 2-tap filter in a linalg.generic and zero-inserting dL/dy with a strided tensor.insert_slice, while Fusilli permutes both operands with linalg.transpose ([0, 3, 2, 1]) before an analogous insert_slice + reduction and transposes the result back, which also changes the dispatches that get formed (see the sketch after each IR dump).

BOO:

#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d4)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
module {
  util.func public @fused_op_layout_cus_ce99c970a1fbf4682883ac5fef7d1b047d72a5c2_16x1x256x2048xbfloat16_16x1x512x2048xbfloat16_2048x1x2x2048xbfloat16$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.fence, %arg4: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
    %cst = arith.constant 0.000000e+00 : f32
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %cst_0 = arith.constant 0.000000e+00 : bf16
    %cst_1 = arith.constant dense<0.000000e+00> : tensor<16x1x513x2048xbf16>
    %0 = hal.tensor.import wait(%arg3) => %arg0 : !hal.buffer_view -> tensor<16x1x256x2048xbf16>
    %1 = hal.tensor.import wait(%arg3) => %arg2 : !hal.buffer_view -> tensor<2048x1x2x2048xbf16>
    %2 = tensor.empty() : tensor<2048x1x2x2048xbf16>
    %3 = linalg.fill ins(%cst_0 : bf16) outs(%2 : tensor<2048x1x2x2048xbf16>) -> tensor<2048x1x2x2048xbf16>
    %4 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1 : tensor<2048x1x2x2048xbf16>) outs(%3 : tensor<2048x1x2x2048xbf16>) {
    ^bb0(%in: bf16, %out: bf16):
      %12 = linalg.index 0 : index
      %13 = linalg.index 2 : index
      %14 = linalg.index 3 : index
      %15 = arith.subi %c1, %13 : index
      %extracted = tensor.extract %1[%12, %c0, %15, %14] : tensor<2048x1x2x2048xbf16>
      linalg.yield %extracted : bf16
    } -> tensor<2048x1x2x2048xbf16>
    %inserted_slice = tensor.insert_slice %0 into %cst_1[0, 0, 1, 0] [16, 1, 256, 2048] [1, 1, 2, 1] : tensor<16x1x256x2048xbf16> into tensor<16x1x513x2048xbf16>
    %5 = tensor.empty() : tensor<16x1x512x2048xf32>
    %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<16x1x512x2048xf32>) -> tensor<16x1x512x2048xf32>
    %7 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%inserted_slice, %4 : tensor<16x1x513x2048xbf16>, tensor<2048x1x2x2048xbf16>) outs(%6 : tensor<16x1x512x2048xf32>) {
    ^bb0(%in: bf16, %in_2: bf16, %out: f32):
      %12 = arith.extf %in : bf16 to f32
      %13 = arith.extf %in_2 : bf16 to f32
      %14 = arith.mulf %12, %13 : f32
      %15 = arith.addf %out, %14 : f32
      linalg.yield %15 : f32
    } -> tensor<16x1x512x2048xf32>
    %8 = tensor.empty() : tensor<16x1x512x2048xbf16>
    %9 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%7 : tensor<16x1x512x2048xf32>) outs(%8 : tensor<16x1x512x2048xbf16>) {
    ^bb0(%in: f32, %out: bf16):
      %12 = arith.truncf %in : f32 to bf16
      linalg.yield %12 : bf16
    } -> tensor<16x1x512x2048xbf16>
    %10 = hal.tensor.barrier join(%9 : tensor<16x1x512x2048xbf16>) => %arg4 : !hal.fence
    %11 = hal.tensor.export %10 : tensor<16x1x512x2048xbf16> -> !hal.buffer_view
    util.return %11 : !hal.buffer_view
  }
  util.func public @fused_op_layout_cus_ce99c970a1fbf4682883ac5fef7d1b047d72a5c2_16x1x256x2048xbfloat16_16x1x512x2048xbfloat16_2048x1x2x2048xbfloat16(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
    %0 = util.null : !hal.fence
    %c-1_i32 = arith.constant -1 : i32
    %c0 = arith.constant 0 : index
    %device_0 = hal.devices.get %c0 : !hal.device
    %fence = hal.fence.create device(%device_0 : !hal.device) flags("None") : !hal.fence
    %1 = util.call @fused_op_layout_cus_ce99c970a1fbf4682883ac5fef7d1b047d72a5c2_16x1x256x2048xbfloat16_16x1x512x2048xbfloat16_2048x1x2x2048xbfloat16$async(%arg0, %arg1, %arg2, %0, %fence) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.fence, !hal.fence) -> !hal.buffer_view
    %status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) flags("None") : i32
    util.return %1 : !hal.buffer_view
  }
}
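
What the BOO IR does: the strided tensor.insert_slice scatters the 256-wide dL/dy at offset 1 with stride 2 into a zero-filled 513-wide buffer, the first linalg.generic flips the 2-tap filter (the %c1 - index in its body), and the 7-D reduction then runs a dense stride-1 correlation to produce the 512-wide input gradient. A minimal 1-D NumPy sketch of that decomposition, checked against the forward definition (H=512, K=2, stride=2 as in the issue; all names are illustrative):

import numpy as np

H, K, S = 512, 2, 2
rng = np.random.default_rng(0)
w = rng.standard_normal(K)
dy = rng.standard_normal((H - K) // S + 1)  # 256 output-gradient values

# Reference dL/dx from the forward definition y[j] = sum_k x[j*S + k] * w[k].
dx_ref = np.zeros(H)
for j in range(dy.size):
    for k in range(K):
        dx_ref[j * S + k] += dy[j] * w[k]

# BOO-style formulation: zero-insert dy, flip w, dense stride-1 correlate.
buf = np.zeros(H + 1)                  # the 513-wide zero buffer (%cst_1)
buf[1 : 1 + S * dy.size : S] = dy      # insert_slice at offset 1, stride 2
w_flip = w[::-1]                       # the %c1 - index filter flip
dx = np.array([buf[i : i + K] @ w_flip for i in range(H)])

assert np.allclose(dx, dx_ref)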

Fusilli:

#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d5, d6)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module @module {
  util.func public @main$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.fence, %arg4: !hal.fence) attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
    %cst = arith.constant 0.000000e+00 : bf16
    %cst_0 = arith.constant 0.000000e+00 : f32
    %0 = hal.tensor.import wait(%arg3) => %arg1 : !hal.buffer_view -> tensor<16x256x1x2048xbf16>
    %1 = hal.tensor.import wait(%arg3) => %arg2 : !hal.buffer_view -> tensor<2048x2x1x2048xbf16>
    %2 = tensor.empty() : tensor<16x2048x1x256xbf16>
    %transposed = linalg.transpose ins(%0 : tensor<16x256x1x2048xbf16>) outs(%2 : tensor<16x2048x1x256xbf16>) permutation = [0, 3, 2, 1] 
    %3 = tensor.empty() : tensor<2048x2048x1x2xbf16>
    %transposed_1 = linalg.transpose ins(%1 : tensor<2048x2x1x2048xbf16>) outs(%3 : tensor<2048x2048x1x2xbf16>) permutation = [0, 3, 2, 1] 
    %4 = tensor.empty() : tensor<16x2048x1x512xbf16>
    %5 = tensor.empty() : tensor<16x2048x1x513xbf16>
    %6 = linalg.fill ins(%cst : bf16) outs(%5 : tensor<16x2048x1x513xbf16>) -> tensor<16x2048x1x513xbf16>
    %inserted_slice = tensor.insert_slice %transposed into %6[0, 0, 0, 1] [16, 2048, 1, 256] [1, 1, 1, 2] : tensor<16x2048x1x256xbf16> into tensor<16x2048x1x513xbf16>
    %7 = tensor.empty() : tensor<16x2048x1x512xf32>
    %8 = linalg.fill ins(%cst_0 : f32) outs(%7 : tensor<16x2048x1x512xf32>) -> tensor<16x2048x1x512xf32>
    %9 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%inserted_slice, %transposed_1 : tensor<16x2048x1x513xbf16>, tensor<2048x2048x1x2xbf16>) outs(%8 : tensor<16x2048x1x512xf32>) {
    ^bb0(%in: bf16, %in_3: bf16, %out: f32):
      %14 = arith.extf %in : bf16 to f32
      %15 = arith.extf %in_3 : bf16 to f32
      %16 = arith.mulf %14, %15 : f32
      %17 = arith.addf %16, %out : f32
      linalg.yield %17 : f32
    } -> tensor<16x2048x1x512xf32>
    %10 = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%9 : tensor<16x2048x1x512xf32>) outs(%4 : tensor<16x2048x1x512xbf16>) {
    ^bb0(%in: f32, %out: bf16):
      %14 = arith.truncf %in : f32 to bf16
      linalg.yield %14 : bf16
    } -> tensor<16x2048x1x512xbf16>
    %11 = tensor.empty() : tensor<16x512x1x2048xbf16>
    %transposed_2 = linalg.transpose ins(%10 : tensor<16x2048x1x512xbf16>) outs(%11 : tensor<16x512x1x2048xbf16>) permutation = [0, 3, 2, 1] 
    %12 = hal.tensor.alias wait(%arg3) => %transposed_2 : tensor<16x512x1x2048xbf16> to %arg0 : !hal.buffer_view
    %13 = hal.tensor.barrier join(%12 : tensor<16x512x1x2048xbf16>) => %arg4 : !hal.fence
    util.return
  }
  util.func public @main(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) attributes {iree.abi.stub} {
    %0 = util.null : !hal.fence
    %c-1_i32 = arith.constant -1 : i32
    %c0 = arith.constant 0 : index
    %device_0 = hal.devices.get %c0 : !hal.device
    %fence = hal.fence.create device(%device_0 : !hal.device) flags("None") : !hal.fence
    util.call @main$async(%arg0, %arg1, %arg2, %0, %fence) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.fence, !hal.fence) -> ()
    %status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) flags("None") : i32
    util.return
  }
}
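
The Fusilli IR computes the same values, but only after permuting dL/dy and the filter with linalg.transpose ([0, 3, 2, 1]) so the strided spatial axis lands innermost, and permuting the result back, i.e. three extra transpose ops around an analogous zero-insert + correlation (note its dump has no explicit filter-flip generic, so the flip is presumably handled upstream). A minimal NumPy sketch of why such transpose wrapping is value-preserving and costs only data movement (shapes reduced to N=2, H=8, C=3; the kernel is the same decomposition as the sketch above; all names are illustrative):

import numpy as np

N, H, C, K, S = 2, 8, 3, 2, 2
rng = np.random.default_rng(1)
dy = rng.standard_normal((N, (H - K) // S + 1, 1, C))  # NHWC-like layout
w = rng.standard_normal(K)

def data_grad(t, axis):
    # Scatter t along `axis` with stride S, then correlate with the
    # flipped 2-tap filter -- the same decomposition as the sketch above.
    t = np.moveaxis(t, axis, -1)
    buf = np.zeros(t.shape[:-1] + (H + 1,))
    buf[..., 1 : 1 + S * t.shape[-1] : S] = t
    out = np.stack([buf[..., i : i + K] @ w[::-1] for i in range(H)], axis=-1)
    return np.moveaxis(out, -1, axis)

# BOO-style: the kernel's indexing maps absorb the layout directly.
direct = data_grad(dy, axis=1)

# Fusilli-style: transpose in, run the same kernel innermost, transpose out.
wrapped = data_grad(dy.transpose(0, 3, 2, 1), axis=3).transpose(0, 3, 2, 1)

assert np.allclose(direct, wrapped)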

And the IRs after dispatch generation: https://gist.github.com/yzhang93/642c4ce5443e5e3b49f7f73409811246
