diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 308718ade7d..615014a3294 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -802,6 +802,47 @@ def check_conv_node(node: torch.fx.Node) -> bool: return True + def pick_conv_storage( + node: torch.fx.Node, + ) -> Tuple[List[utils.TensorRepSet], utils.TensorRepSet]: + x = node.args[0] + assert isinstance(x, torch.fx.Node) + x_shape = x.meta["val"].size() + + # Default: channels-packed texture (conv2d and fallback conv1d) + input_storage = utils.CHANNELS_PACKED_TEXTURE + output_storage = utils.CHANNELS_PACKED_TEXTURE + + if len(x_shape) == 3: + # Conv1d: check if we can use height-packed + weight = node.args[1] + assert isinstance(weight, torch.fx.Node) + w_shape = weight.meta["val"].size() + groups = node.args[8] + + c_in = x_shape[1] + c_out = w_shape[0] + kernel_size = w_shape[2] + + is_pointwise = kernel_size == 1 + is_depthwise = ( + isinstance(groups, int) + and groups == c_in + and c_out == c_in + and w_shape[1] == 1 + ) + if is_pointwise or is_depthwise: + input_storage = utils.HEIGHT_PACKED_TEXTURE + output_storage = utils.HEIGHT_PACKED_TEXTURE + + # Build per-input storage list. 
The convolution op has variable args: + # aten.convolution.default: input, weight, bias, stride, padding, + # dilation, transposed, output_padding, groups + # et_vk.conv_with_clamp.default: + output_min, output_max + # All args after input are NO_STORAGE (prepacked or non-tensor) + inputs = [input_storage] + [utils.NO_STORAGE] * 10 + return inputs, output_storage + return OpFeatures( inputs_storage=[ utils.CHANNELS_PACKED_TEXTURE, # input @@ -820,6 +861,7 @@ def check_conv_node(node: torch.fx.Node) -> bool: supports_resize=True, supports_prepacking=True, are_node_inputs_supported_fn=check_conv_node, + pick_io_storage_fn=pick_conv_storage, ) diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index 077ce285cfc..11e7443f785 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -675,6 +675,56 @@ void conv(ComputeGraph& graph, const std::vector<ValueRef>& args) { true); } } else { + // Conv1d path + if (graph.packed_dim_of(args[0]) == WHCN::kHeightDim) { + // Height-packed: route to optimized conv1d implementations + const auto weight_sizes = graph.sizes_of(args[1]); + const int64_t groups_val = graph.get_int(args[8]); + const bool is_pointwise = weight_sizes.at(2) == 1; + const bool is_depthwise = + groups_val == weight_sizes.at(0) && weight_sizes.at(1) == 1; + + // Build unified 10-arg vector: + // in, weight, bias, stride, padding, dilation, groups, + // output_min, output_max, out + // For non-clamp (args.size() == 10): output_min/max = kDummyValueRef + // For clamp (args.size() == 12): output_min/max from args[9]/args[10] + ValueRef output_min = kDummyValueRef; + ValueRef output_max = kDummyValueRef; + ValueRef out; + if (args.size() == 10) { + out = args[9]; + } else { + output_min = args[9]; + output_max = args[10]; + out = args[11]; + } + + std::vector<ValueRef> conv1d_args = { + args[0], + args[1], + args[2], + args[3], + args[4],
+ args[5], + args[8], + output_min, + output_max, + out}; + + if (is_pointwise) { + VK_GET_OP_FN("et_vk.conv1d_pw.default")(graph, conv1d_args); + } else if (is_depthwise) { + VK_GET_OP_FN("et_vk.conv1d_dw.default")(graph, conv1d_args); + } else { + VK_THROW( + "Height-packed conv1d only supports pointwise (K=1) or " + "depthwise (groups=C)"); + } + return; + } + + // Existing channels-packed fallback if (args.size() == 10) { // ordinary conv1d return add_conv1d_node(