Skip to content

Commit a93a917

Browse files
rascani and claude
authored and committed
Cortex-M backend: Fix conv2d scratch buffer allocation to match CMSIS-NN wrapper dispatch (pytorch#17766)
### Summary

Use `arm_convolve_wrapper_s8_get_buffer_size` instead of `arm_convolve_s8_get_buffer_size` so the buffer size matches whichever specialized kernel `arm_convolve_wrapper_s8` will actually dispatch to at runtime (1x1 fast, 1xN, or general).

Also remove the `Error::NotFound` carve-out that silently proceeded with a null scratch buffer — CMSIS-NN returns `ARM_CMSIS_NN_ARG_ERROR` when `ctx->buf` is NULL and a buffer is required, so fail immediately on any allocation error, consistent with the other cortex_m conv ops.

Update CMSIS-NN from v7.0.0 to 098d54a61e3e04e2b60e9010e871eef036ac5ae3 and remove `GIT_SHALLOW` (incompatible with SHA-based FetchContent pins).

Fixes pytorch#18044

cc @digantdesai @SS-JIA @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell

---------

Co-authored-by: Claude <noreply@anthropic.com>
1 parent e5354d1 commit a93a917

2 files changed

Lines changed: 19 additions & 16 deletions

File tree

backends/cortex_m/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ include(FetchContent)
2323

2424
# CMSIS-NN configuration with dynamic path detection
2525
set(CMSIS_NN_VERSION
26-
"v7.0.0"
26+
"098d54a61e3e04e2b60e9010e871eef036ac5ae3"
2727
CACHE STRING "CMSIS-NN version to download"
2828
)
2929
set(CMSIS_NN_LOCAL_PATH
@@ -45,7 +45,6 @@ else()
4545
cmsis_nn
4646
GIT_REPOSITORY https://github.com/ARM-software/CMSIS-NN.git
4747
GIT_TAG ${CMSIS_NN_VERSION}
48-
GIT_SHALLOW TRUE
4948
)
5049

5150
FetchContent_MakeAvailable(cmsis_nn)

backends/cortex_m/ops/op_quantized_conv2d.cpp

Lines changed: 18 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -184,24 +184,28 @@ Tensor& quantized_conv2d_out(
184184
cmsis_context.buf = nullptr;
185185
cmsis_context.size = 0;
186186

187-
const size_t buffer_bytes = static_cast<size_t>(
188-
arm_convolve_s8_get_buffer_size(&input_dims, &filter_dims));
187+
const int32_t buffer_bytes = arm_convolve_wrapper_s8_get_buffer_size(
188+
&conv_params, &input_dims, &filter_dims, &output_dims);
189+
if (buffer_bytes < 0) {
190+
ET_LOG(
191+
Error, "quantized_conv2d_out: CMSIS-NN buffer size calculation failed");
192+
context.fail(Error::Internal);
193+
return out;
194+
}
189195
if (buffer_bytes > 0) {
190196
auto buffer_or_error =
191197
context.allocate_temp(buffer_bytes, kCortexMMveAlignment);
192198
if (!buffer_or_error.ok()) {
193-
if (buffer_or_error.error() != Error::NotFound) {
194-
ET_LOG(
195-
Error,
196-
"quantized_conv2d_out: failed to allocate scratch buffer (%d)",
197-
static_cast<int>(buffer_or_error.error()));
198-
context.fail(buffer_or_error.error());
199-
return out;
200-
}
201-
} else {
202-
cmsis_context.buf = buffer_or_error.get();
203-
cmsis_context.size = buffer_bytes;
199+
ET_LOG(
200+
Error,
201+
"quantized_conv2d_out: failed to allocate scratch buffer (%d bytes, error %d)",
202+
static_cast<int>(buffer_bytes),
203+
static_cast<int>(buffer_or_error.error()));
204+
context.fail(buffer_or_error.error());
205+
return out;
204206
}
207+
cmsis_context.buf = buffer_or_error.get();
208+
cmsis_context.size = buffer_bytes;
205209
}
206210

207211
const arm_cmsis_nn_status status = arm_convolve_wrapper_s8(
@@ -220,7 +224,7 @@ Tensor& quantized_conv2d_out(
220224
if (status != ARM_CMSIS_NN_SUCCESS) {
221225
ET_LOG(
222226
Error,
223-
"quantized_conv2d_out: arm_convolve_s8 failed with status %d",
227+
"quantized_conv2d_out: arm_convolve_wrapper_s8 failed with status %d",
224228
status);
225229
context.fail(Error::Internal);
226230
}

0 commit comments

Comments
 (0)