From f9150db3168a255a03fa2c24b5bd13a5753db992 Mon Sep 17 00:00:00 2001 From: qflen Date: Fri, 24 Apr 2026 03:28:38 +0200 Subject: [PATCH] Handle non-multiple-of-8 spatial dims in depthwise conv2d Metal path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The depthwise Metal kernel dispatches (tc=8, tw=8, th=4)-sized threadgroups across (C/tc, oW/tw, (oH/th)*N) blocks and gates itself in dispatch_conv_2D_gpu with `oS[0] % 8 == 0 && oS[1] % 8 == 0`. Any other output spatial shape falls back to the generic implicit/explicit gemm path, which is 2-5x slower on M-series. Multi-stage vision encoders hit this at every non-mod-8 intermediate feature map. Relax the gate and make the kernel tail-aware: - dispatch_conv_2D_gpu: ceil-div `grid_dims` for oS so tail tiles are launched, and drop the `oS % 8 == 0` gate. - depthwise_conv_2d: ceil-div `n_tgblocks_h` to match the dispatch, and skip the output write when `oh` or `ow` is past `oS`. The load path is already bounds-checked, and out-of-range threads still cooperate in shared-memory loads — only the store needs guarding. Measured on M5 32GB with the issue reproducer (3x3, pad=1, stride=1, depthwise, median of 3 runs): | channels | res | baseline ms | patched ms | speedup | |---------:|:-------|------------:|-----------:|--------:| | 384 | 60x60 | 1.64 | 0.34 | 4.8x | | 768 | 30x30 | 0.94 | 0.28 | 3.4x | | 1536 | 15x15 | 0.61 | 0.23 | 2.7x | Mod-8 shapes are unaffected (within ±2% across runs). Adds a GPU-vs-CPU allclose test in tests/gpu_tests.cpp covering non-mod-8 3x3/5x5 shapes, asymmetric spatial alignment, and stride-2 downsampling. Closes #3324. 
--- mlx/backend/metal/conv.cpp | 5 ++-- mlx/backend/metal/kernels/conv.metal | 6 ++++- tests/gpu_tests.cpp | 40 ++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 5d032779d3..d9e0711b02 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -962,7 +962,9 @@ void depthwise_conv_2D_gpu( MTL::Size group_dims = MTL::Size(tc, tw, th); MTL::Size grid_dims = MTL::Size( - conv_params.C / tc, conv_params.oS[1] / tw, (conv_params.oS[0] / th) * N); + conv_params.C / tc, + (conv_params.oS[1] + tw - 1) / tw, + ((conv_params.oS[0] + th - 1) / th) * N); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -986,7 +988,6 @@ void dispatch_conv_2D_gpu( if (C_per_group == 1 && O_per_group == 1 && is_kdil_one && conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 && conv_params.str[0] <= 2 && conv_params.str[1] <= 2 && - conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 && conv_params.wt_strides[1] == conv_params.wS[1] && conv_params.C % 16 == 0 && conv_params.C == conv_params.O) { return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params); diff --git a/mlx/backend/metal/kernels/conv.metal b/mlx/backend/metal/kernels/conv.metal index 24522bb714..220afc307d 100644 --- a/mlx/backend/metal/kernels/conv.metal +++ b/mlx/backend/metal/kernels/conv.metal @@ -206,7 +206,7 @@ template threadgroup T ins[TGH * TGW * TGC]; - const int n_tgblocks_h = params.oS[0] / th; + const int n_tgblocks_h = (params.oS[0] + th - 1) / th; const int n = tid.z / n_tgblocks_h; const int tghid = tid.z % n_tgblocks_h; const int oh = tghid * th + lid.z; @@ -277,6 +277,10 @@ template } threadgroup_barrier(mem_flags::mem_none); + if (oh >= params.oS[0] || ow >= params.oS[1]) { + return; + } + out += n * params.out_strides[0] + oh * params.out_strides[1] + ow * params.out_strides[2]; out[c] = static_cast<T>(o); diff --git a/tests/gpu_tests.cpp b/tests/gpu_tests.cpp index 
58cca348e5..007fc60a9d 100644 --- a/tests/gpu_tests.cpp +++ b/tests/gpu_tests.cpp @@ -521,3 +521,43 @@ TEST_CASE("test memory info") { clear_cache(); CHECK_EQ(get_cache_memory(), 0); } + +TEST_CASE("test gpu depthwise conv2d non-mod-8 spatial") { + // Depthwise Metal kernel is gated on C_per_group == 1, C == O, C % 16 == 0, + // kernel <= 7, stride <= 2. Previously the dispatch also required the + // output spatial dims to be multiples of 8; non-multiples fell back to the + // gemm path. These cases exercise the tail-tile dispatch now handled by + // the depthwise kernel itself and verify GPU ≈ CPU. + struct Case { + int N, H, W, C, kH, kW; + std::pair<int, int> stride; + std::pair<int, int> padding; + }; + std::vector<Case> cases = { + // Matches the issue reproducer resolutions, reduced batch. + {1, 15, 15, 16, 3, 3, {1, 1}, {1, 1}}, + {1, 30, 30, 32, 3, 3, {1, 1}, {1, 1}}, + {1, 60, 60, 16, 3, 3, {1, 1}, {1, 1}}, + // Non-square and non-matching alignment on the two spatial axes. + {1, 17, 30, 16, 5, 5, {1, 1}, {2, 2}}, + {1, 8, 13, 16, 3, 3, {1, 1}, {1, 1}}, + // Stride 2 with non-mod-8 output. + {1, 19, 19, 32, 3, 3, {2, 2}, {1, 1}}, + }; + + auto key = random::key(42); + for (const auto& c : cases) { + auto in = random::normal( + {c.N, c.H, c.W, c.C}, float32, 0.0f, 1.0f, key, Device::cpu); + auto wt = random::normal( + {c.C, c.kH, c.kW, 1}, float32, 0.0f, 1.0f, key, Device::cpu); + eval(in); + eval(wt); + + auto out_cpu = + conv2d(in, wt, c.stride, c.padding, {1, 1}, c.C, Device::cpu); + auto out_gpu = + conv2d(in, wt, c.stride, c.padding, {1, 1}, c.C, Device::gpu); + CHECK(allclose(out_cpu, out_gpu, 1e-4, 1e-4).item<bool>()); + } +}