Hello,

I encountered an issue with the concatenate op while adapting OpenXLA for an NPU. The error message is:

Dimension bounds are not divisible by tile size: (d0, d1) -> (d0 - 2, d1), domain: d0 in [2, 3], d1 in [0, 3]

Here the tile size is 4, while the bounds of d0 are [2, 3], an extent of only 2. The test case concatenates two (2, 4) inputs along axis=0 into a (4, 4) output. I noticed that when my code recursively calls ComputeTiledHloInstructionsImpl, the flat_tiling_parameters are always [4, 4], which is the output's tile size, even when tiling the (2, 4) operands. A worked check of this arithmetic is sketched after the generated module below.

Set operands tile sizes:

// Configure block-level tiling for Triton
auto block_config =
fusion_backend_config->mutable_block_level_fusion_config();
block_config->set_num_warps(1);
block_config->set_num_ctas(1);
block_config->set_num_stages(1);
block_config->set_is_tma_allowed(false);
// Identify the dimension along which the concatenation occurs
int64_t concat_dim = concat->concatenate_dimension();
// Set the output tile of this operand fusion to the operand's full
// shape, i.e. every dimension (including concat_dim) keeps its full size.
auto* output_tile = block_config->add_output_tiles();
for (int64_t i = 0; i < operand->shape().dimensions_size(); ++i) {
output_tile->add_sizes(operand->shape().dimensions(i));
}
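
For reference, the [2, 4] operand tiles this produces are exactly what clamping the root tile to each operand's shape would give. Below is a minimal sketch of that clamping arithmetic; ClampTileToOperand is a hypothetical helper for illustration only, not an XLA API:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical helper: shrink a root tile so that no tile extent exceeds
// the operand's extent along the same axis.
std::vector<int64_t> ClampTileToOperand(
    const std::vector<int64_t>& root_tile,
    const std::vector<int64_t>& operand_dims) {
  std::vector<int64_t> tile(root_tile.size());
  for (size_t i = 0; i < root_tile.size(); ++i) {
    tile[i] = std::min(root_tile[i], operand_dims[i]);
  }
  return tile;
}

int main() {
  // Root tile [4, 4] against a (2, 4) concat operand yields [2, 4],
  // which evenly covers the operand's d0 domain [2, 3] (extent 2).
  const std::vector<int64_t> tile = ClampTileToOperand({4, 4}, {2, 4});
  std::printf("[%ld, %ld]\n", static_cast<long>(tile[0]),
              static_cast<long>(tile[1]));
  return 0;
}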
Set output tile sizes:

// Derive block-level tile sizes from the fusion's result shape: every
// dimension is tiled at its full size, so the root is a single tile.
void GetBlockLevelParameters(
    const xla::HloComputation* fusion_comp,
    xla::gpu::BlockLevelParameters& block_level_parameters) {
  const xla::Shape shape = fusion_comp->ComputeProgramShape().result();
  VLOG(3) << "GetBlockLevelParameters fusion_comp: " << fusion_comp->ToString()
          << " shape = " << shape.ToString();
  // For tuple-shaped results, tile the first element's shape.
  const xla::Shape& result_shape =
      shape.IsTuple() ? shape.tuple_shapes()[0] : shape;
  std::vector<int64_t> tile_sizes;
  for (const int64_t s : result_shape.dimensions()) {
    VLOG(3) << "block_level_parameters tile size = " << s;
    tile_sizes.push_back(s);
  }
  block_level_parameters.output_tile_sizes.push_back(std::move(tile_sizes));
}
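
An illustrative call site for the helper above (the `fusion` variable is assumed here to be the outer kCustom fusion instruction from the module below; this is not code from my actual pass):

// Illustrative call site: `fusion` is assumed to be the outer kCustom
// fusion instruction.
xla::gpu::BlockLevelParameters block_level_parameters;
GetBlockLevelParameters(fusion->fused_instructions_computation(),
                        block_level_parameters);
// For the f32[4,4] root: output_tile_sizes == {{4, 4}}.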
Generated Code:

HloModule a_inference_test_concatenate_10__XlaMustCompile_true_config_proto_16571209933021685656_executor_type_11160318154034397263_.9, is_scheduled=true, entry_computation_layout={(f32[2,4]{1,0}, f32[2,4]{1,0})->f32[4,4]{1,0}}
%triton_concat_operand_fusion_arg0.0 (param: f32[2,4]) -> f32[2,4] {
%param = f32[2,4]{1,0} parameter(0)
ROOT %copy = f32[2,4]{1,0} copy(%param)
}
%triton_concat_operand_fusion_arg1.0 (param.1: f32[2,4]) -> f32[2,4] {
%param.1 = f32[2,4]{1,0} parameter(0)
ROOT %copy.1 = f32[2,4]{1,0} copy(%param.1)
}
%a_inference_test_concatenate_10__XlaMustCompile_true_config_proto_16571209933021685656_executor_type_11160318154034397263_.0 (arg0.0: f32[2,4], arg1.0: f32[2,4]) -> f32[4,4] {
%arg0.0 = f32[2,4]{1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}
%fusion.1 = f32[2,4]{1,0} fusion(%arg0.0), kind=kCustom, calls=%triton_concat_operand_fusion_arg0.0, metadata={op_type="ConcatV2" op_name="concat" source_file="/usr/local/lib/python3.10/dist-packages/tensorflow/python/framework/ops.py" source_line=1221}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton","block_level_fusion_config":{"num_warps":"1","output_tiles":[{"sizes":["2","4"]}],"num_ctas":1,"num_stages":1,"is_tma_allowed":false}},"force_earliest_schedule":false,"reification_cost":[]}
%arg1.0 = f32[2,4]{1,0} parameter(1), parameter_replication={false}, metadata={op_name="XLA_Args"}
%fusion.2 = f32[2,4]{1,0} fusion(%arg1.0), kind=kCustom, calls=%triton_concat_operand_fusion_arg1.0, metadata={op_type="ConcatV2" op_name="concat" source_file="/usr/local/lib/python3.10/dist-packages/tensorflow/python/framework/ops.py" source_line=1221}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton","block_level_fusion_config":{"num_warps":"1","output_tiles":[{"sizes":["2","4"]}],"num_ctas":1,"num_stages":1,"is_tma_allowed":false}},"force_earliest_schedule":false,"reification_cost":[]}
ROOT %concatenate.0 = f32[4,4]{1,0} concatenate(%fusion.1, %fusion.2), dimensions={0}, metadata={op_type="ConcatV2" op_name="concat" source_file="/usr/local/lib/python3.10/dist-packages/tensorflow/python/framework/ops.py" source_line=1221}
}
ENTRY %a_inference_test_concatenate_10__XlaMustCompile_true_config_proto_16571209933021685656_executor_type_11160318154034397263_.9 (arg0.1: f32[2,4], arg1.2: f32[2,4]) -> f32[4,4] {
%arg1.2 = f32[2,4]{1,0} parameter(1), parameter_replication={false}, metadata={op_name="XLA_Args"}
%arg0.1 = f32[2,4]{1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}
ROOT %fusion_npu_triton_fusion = f32[4,4]{1,0} fusion(%arg0.1, %arg1.2), kind=kCustom, calls=%a_inference_test_concatenate_10__XlaMustCompile_true_config_proto_16571209933021685656_executor_type_11160318154034397263_.0, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton","block_level_fusion_config":{"num_warps":"1","output_tiles":[{"sizes":["4","4"]}],"num_ctas":1,"num_stages":1,"is_tma_allowed":false}},"force_earliest_schedule":false,"reification_cost":[]}
}
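
To make the failing check concrete: the second operand occupies output rows [2, 3], an extent of 2, and the propagated root tile of 4 does not divide that extent. A minimal sketch of the divisibility condition follows (the names are illustrative, not the symbolic tiling code's actual API):

#include <cstdint>
#include <cstdio>

// A tile of size `tile` laid over a dimension whose domain is [lo, hi]
// satisfies the divisibility requirement only if the domain extent
// (hi - lo + 1) is a multiple of the tile size.
bool TileDividesDomain(int64_t tile, int64_t lo, int64_t hi) {
  return (hi - lo + 1) % tile == 0;
}

int main() {
  // Propagated root tile 4 on the second operand's row domain [2, 3]:
  // extent 2 is not divisible by 4, hence the error above.
  std::printf("%d\n", TileDividesDomain(4, 2, 3));  // 0
  // The operand-sized tile of 2 configured on the operand fusions fits.
  std::printf("%d\n", TileDividesDomain(2, 2, 3));  // 1
  return 0;
}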