Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -889,11 +889,13 @@ void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback>
// Create arm_gemm fallback
auto fallback = std::make_unique<Fallback<TypeInput, TypeWeight, TypeOutput, arm_gemm::DequantizeFloat>>();

// Configure requantization info
const GEMMLowpOutputStageInfo os_info = info.output_stage;

// The dequantize scale is overridden in Fallback::configure via set_dequantize_scale.
// Offsets are taken from AsmGemmInfo (set by CpuGemmDirectConv2d; zero for other callers).
arm_gemm::DequantizeFloat gemm_dequant_info{};
gemm_dequant_info = arm_gemm::DequantizeFloat(d->quantization_info().uniform().scale);
gemm_dequant_info = arm_gemm::DequantizeFloat(
a->quantization_info().uniform().scale * b->quantization_info().uniform().scale,
info.dequant_a_offset,
info.dequant_b_offset);

fallback->configure(a, b, c, d, args, info, gemm_dequant_info);
arm_gemm = std::move(fallback);
Expand Down Expand Up @@ -1020,6 +1022,13 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
{})),
"We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output");
}
else if (d->data_type() == DataType::F32)
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
!(arm_gemm::has_opt_gemm<int8_t, int8_t, float, arm_gemm::DequantizeFloat>(arm_gemm_expected_wf,
args, {})),
"We could not find an optimized kernel for S8/QASYMM8_SIGNED input and F32 output");
}
else
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
Expand Down Expand Up @@ -1130,6 +1139,11 @@ Status CpuGemmAssemblyDispatch::validate(
a->data_type() == DataType::QASYMM8 &&
(d->data_type() != DataType::QASYMM8 && d->data_type() != DataType::S32 && d->data_type() != DataType::F32),
"Only QASYMM8/S32/F32 output supported for QASYMM8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
a->data_type() == DataType::QASYMM8_SIGNED &&
(d->data_type() != DataType::QASYMM8_SIGNED && d->data_type() != DataType::S32 &&
d->data_type() != DataType::F32),
"Only QASYMM8_SIGNED/S32/F32 output supported for QASYMM8_SIGNED input");
arm_compute::WeightFormat expected_weight_format = arm_compute::WeightFormat::UNSPECIFIED;
const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info);
if (bool(ret) && expected_weight_format != arm_compute::WeightFormat::ANY)
Expand Down
6 changes: 4 additions & 2 deletions src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2025 Arm Limited.
* Copyright (c) 2018-2026 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
Expand Down Expand Up @@ -65,7 +65,9 @@ struct AsmGemmInfo
* @note This flag will be silently ignored (assumed to be false) when the weight_format is a fixed format. Because
* fixed format kernels do not accept weights (B) with any prior transformations
*/
bool transpose_b{false};
bool transpose_b{false};
int32_t dequant_a_offset{0}; // input zero-point for DequantizeFloat path (handled in kernel)
int32_t dequant_b_offset{0}; // weight zero-point for DequantizeFloat path (handled in kernel)
};

/** Assembly kernel glue */
Expand Down