diff --git a/tests/validation/runtime/experimental/low_level/CpuGemmAssemblyDispatch.cpp b/tests/validation/runtime/experimental/low_level/CpuGemmAssemblyDispatch.cpp
index 031f7abd8b..f0123fd975 100644
--- a/tests/validation/runtime/experimental/low_level/CpuGemmAssemblyDispatch.cpp
+++ b/tests/validation/runtime/experimental/low_level/CpuGemmAssemblyDispatch.cpp
@@ -57,9 +57,17 @@ constexpr float tolerance_num = 0.07f; /**< Tolerance number for FP16 data types
 #endif /* ARM_COMPUTE_ENABLE_FP16 */
 #ifdef ARM_COMPUTE_ENABLE_BF16
 const AbsoluteTolerance abs_tolerance_bf16(
-    0.02f); /**< Absolute tolerance value for comparing reference's output against implementation's output for BF16 data types */
+    0.02f); /**< Absolute tolerance value for comparing reference's output against implementation's output for BF16 data types
+    We have a large absolute error tolerance for bf16 because even though we're computing with bf16 precision in
+    the reference implementation, the actual implementation might still choose the fp32 implementation due to
+    performance reasons. This is particularly likely for small shapes, where the conversion of the fp32 input to bf16
+    isn't worth it. We don't apply this large absolute tolerance to tests with actual bf16 inputs because we
+    also do the calculation in bf16 arithmetic in the reference implementation. Therefore, we do not expect large
+    differences in reference vs. optimized runs.
+    */
 const RelativeTolerance rel_tolerance_bf16(
     0.02f); /**< Relative tolerance value for comparing reference's output against implementation's output for BF16 data types */
+constexpr float tolerance_num_bf16 = 1e-5f;
 #endif /* ARM_COMPUTE_ENABLE_BF16 */
 /** CNN data types */
 const auto CNNDataTypes = make("DataType",
@@ -552,7 +560,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge,
     if(CPUInfo::get().has_bf16())
     {
         // Validate output
-        validate(Accessor(_target), _reference, rel_tolerance_bf16);
+        validate(Accessor(_target), _reference, rel_tolerance_bf16, tolerance_num_bf16);
     }
     else
     {