Skip to content

Commit e653cad

Browse files
developer0hye and claude committed
perf: optimize CPU deform_conv2d forward pass
Three changes to the CPU deformable convolution forward kernel:

1. Replace at::zeros with at::empty for the columns and out_buf buffers. The deformable_im2col_kernel writes every element of the columns buffer, and out_buf is fully written by addmm_, so zero-initialization is wasted work.

2. Use addmm_ with beta=0 instead of the default beta=1. This avoids accumulating into uninitialized memory while preserving in-place operation (no extra allocation, unlike at::mm).

3. Parallelize deformable_im2col_kernel with at::parallel_for. The im2col loop was the only single-threaded phase in the forward pass (GEMM is already parallelized by BLAS). Each loop iteration writes to a non-overlapping region of the columns buffer, so parallelization is safe.

Benchmark results on Apple M2 (CPU, float32):

| Config    | Before (ms) | After (ms) | Change |
|-----------|-------------|------------|--------|
| small-b1  | 9.76        | 2.44       | -75%   |
| small-b8  | 91.77       | 33.88      | -63%   |
| medium-b1 | 216.70      | 75.80      | -65%   |
| medium-b8 | 1152.09     | 650.00     | -44%   |
| large-b1  | 348.86      | 302.70     | -13%   |
| large-b4  | 1342.75     | 1289.96    | -4%    |

Signed-off-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Yonghye Kwon <developer.0hye@gmail.com>
1 parent 336d36e commit e653cad

1 file changed

Lines changed: 51 additions & 49 deletions

File tree

torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp

Lines changed: 51 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -68,6 +68,7 @@
6868
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp
6969

7070
#include <ATen/ATen.h>
71+
#include <ATen/Parallel.h>
7172
#include <torch/library.h>
7273

7374
namespace vision {
@@ -139,58 +140,60 @@ void deformable_im2col_kernel(
139140
int out_w,
140141
bool use_mask,
141142
scalar_t* columns) {
142-
for (int index = 0; index != n; ++index) {
143-
const int out_x = index % out_w;
144-
const int out_y = (index / out_w) % out_h;
145-
const int out_b = (index / (out_w * out_h)) % batch_sz;
146-
const int in_c = index / (out_w * out_h * batch_sz);
147-
const int out_c = in_c * weight_h * weight_w;
143+
at::parallel_for(0, n, 0, [&](int64_t begin, int64_t end) {
144+
for (int64_t index = begin; index != end; ++index) {
145+
const int out_x = index % out_w;
146+
const int out_y = (index / out_w) % out_h;
147+
const int out_b = (index / (out_w * out_h)) % batch_sz;
148+
const int in_c = index / (out_w * out_h * batch_sz);
149+
const int out_c = in_c * weight_h * weight_w;
148150

149-
int c_per_offset_grp = n_in_channels / n_offset_grps;
150-
const int grp_idx = in_c / c_per_offset_grp;
151+
int c_per_offset_grp = n_in_channels / n_offset_grps;
152+
const int grp_idx = in_c / c_per_offset_grp;
151153

152-
auto columns_ptr = columns +
153-
(out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) +
154-
out_y * out_w + out_x);
154+
auto columns_ptr = columns +
155+
(out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) +
156+
out_y * out_w + out_x);
155157

156-
auto input_ptr = input +
157-
(out_b * (n_in_channels * height * width) + in_c * (height * width));
158+
auto input_ptr = input +
159+
(out_b * (n_in_channels * height * width) + in_c * (height * width));
158160

159-
auto offset_ptr = offset +
160-
(out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h *
161-
out_w;
161+
auto offset_ptr = offset +
162+
(out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h *
163+
out_w;
162164

163-
auto mask_ptr = mask;
164-
if (use_mask) {
165-
mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w *
166-
out_h * out_w;
167-
}
168-
169-
for (int i = 0; i < weight_h; ++i) {
170-
for (int j = 0; j < weight_w; ++j) {
171-
const int mask_idx = i * weight_w + j;
172-
const int offset_idx = 2 * mask_idx;
165+
auto mask_ptr = mask;
166+
if (use_mask) {
167+
mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w *
168+
out_h * out_w;
169+
}
173170

174-
scalar_t mask_value = 1;
175-
if (use_mask) {
176-
mask_value =
177-
mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x];
171+
for (int i = 0; i < weight_h; ++i) {
172+
for (int j = 0; j < weight_w; ++j) {
173+
const int mask_idx = i * weight_w + j;
174+
const int offset_idx = 2 * mask_idx;
175+
176+
scalar_t mask_value = 1;
177+
if (use_mask) {
178+
mask_value =
179+
mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x];
180+
}
181+
182+
const scalar_t offset_h =
183+
offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x];
184+
const scalar_t offset_w = offset_ptr
185+
[(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x];
186+
const scalar_t y =
187+
(out_y * stride_h - pad_h) + i * dilation_h + offset_h;
188+
const scalar_t x =
189+
(out_x * stride_w - pad_w) + j * dilation_w + offset_w;
190+
*columns_ptr =
191+
mask_value * bilinear_interpolate(input_ptr, height, width, y, x);
192+
columns_ptr += batch_sz * out_h * out_w;
178193
}
179-
180-
const scalar_t offset_h =
181-
offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x];
182-
const scalar_t offset_w = offset_ptr
183-
[(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x];
184-
const scalar_t y =
185-
(out_y * stride_h - pad_h) + i * dilation_h + offset_h;
186-
const scalar_t x =
187-
(out_x * stride_w - pad_w) + j * dilation_w + offset_w;
188-
*columns_ptr =
189-
mask_value * bilinear_interpolate(input_ptr, height, width, y, x);
190-
columns_ptr += batch_sz * out_h * out_w;
191194
}
192195
}
193-
}
196+
});
194197
}
195198

196199
void deformable_im2col(
@@ -1013,7 +1016,7 @@ at::Tensor deform_conv2d_forward_kernel(
10131016
out_w});
10141017
}
10151018

1016-
at::Tensor out_buf = at::zeros(
1019+
at::Tensor out_buf = at::empty(
10171020
{batch_sz / n_parallel_imgs,
10181021
out_channels,
10191022
n_parallel_imgs * out_h,
@@ -1035,7 +1038,7 @@ at::Tensor deform_conv2d_forward_kernel(
10351038
weight_c.size(3)});
10361039

10371040
// Sample points and perform convolution
1038-
auto columns = at::zeros(
1041+
auto columns = at::empty(
10391042
{n_in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w},
10401043
input_c.options());
10411044
for (int b = 0; b < batch_sz / n_parallel_imgs; b++) {
@@ -1064,10 +1067,9 @@ at::Tensor deform_conv2d_forward_kernel(
10641067
columns = columns.view(
10651068
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
10661069
for (int g = 0; g < n_weight_grps; g++) {
1067-
out_buf[b][g] = out_buf[b][g]
1068-
.flatten(1)
1069-
.addmm_(weight_c[g].flatten(1), columns[g])
1070-
.view_as(out_buf[b][g]);
1070+
out_buf[b][g]
1071+
.flatten(1)
1072+
.addmm_(weight_c[g].flatten(1), columns[g], 0, 1);
10711073
}
10721074
columns =
10731075
columns.view({columns.size(0) * columns.size(1), columns.size(2)});

0 commit comments

Comments
 (0)