-
Notifications
You must be signed in to change notification settings - Fork 955
Expand file tree
/
Copy pathConvolution.cpp
More file actions
768 lines (678 loc) · 22.1 KB
/
Convolution.cpp
File metadata and controls
768 lines (678 loc) · 22.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Convolution.h>
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
namespace vkcompute {
// Strategy used to run a 2-D convolution. The value is derived from the weight
// sizes, the group count and the transposed flag by get_conv2d_method(), and
// drives shader-name selection and weight packing in this file.
enum class Conv2dMethod : uint8_t {
// Non-transposed conv whose weight has sizes[0] == groups and sizes[1] == 1.
Depthwise,
// Non-transposed, non-depthwise conv whose weight has sizes[2] == 1 and
// sizes[3] == 1.
Pointwise,
// Every other non-transposed conv.
SlidingWindow,
// Any conv with the transposed flag set.
Transposed,
};
// Resize callback for the conv2d dispatch node: recomputes the output sizes
// from the current input sizes and the weight sizes, then virtually resizes
// the output tensor.
//
// args:       args[0].refs[0] is the output, args[1].refs[0] is the input.
// extra_args: {weight_data, stride, padding, dilation, transposed,
//             output_padding}, in that order (as passed by add_conv2d_node).
void resize_conv2d_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  const ValueRef out = args.at(0).refs.at(0);
  const ValueRef self = args.at(1).refs.at(0);

  const size_t ndim = graph->dim_of(self);
  std::vector<int64_t> new_out_sizes(ndim);
  const bool transposed = graph->get_bool(extra_args.at(4));

  std::vector<int64_t> self_sizes = graph->sizes_of(self);

  // Batch: only present for 4-D input, and carried over unchanged.
  if (ndim == 4) {
    new_out_sizes.at(ndim - 4) = self_sizes.at(ndim - 4);
  }

  // Channel: taken from the weight tensor. The weight must be indexed by its
  // own layout, not by the input rank: the previous `ndim - 3` / `ndim - 4`
  // indices were only correct for 4-D input (for 3-D input `ndim - 4` wraps
  // around and `.at()` throws). This mirrors prepack_biases(), which reads
  // the output channel count from sizes[1] when transposed and sizes[0]
  // otherwise.
  TensorRefPtr weight_ref = graph->get_tref(extra_args.at(0));
  const auto& weight_sizes = weight_ref->sizes;
  new_out_sizes.at(ndim - 3) =
      transposed ? weight_sizes.at(1) : weight_sizes.at(0);

  // Height, Width
  const auto& new_out_sizes_hw = calc_out_sizes_hw(
      *graph,
      self_sizes,
      extra_args.at(0),
      /*kernel_size_only = */ false,
      {extra_args.at(1), extra_args.at(2), extra_args.at(3), extra_args.at(5)},
      transposed);
  new_out_sizes.at(ndim - 2) = new_out_sizes_hw.at(0);
  new_out_sizes.at(ndim - 1) = new_out_sizes_hw.at(1);

  graph->virtual_resize(out, new_out_sizes);
}
// Resize callback for the conv1d dispatch node.
//
// args:       args[0].refs[0] is the output, args[1].refs[0] is the input.
// extra_args: {weight, stride, padding, dilation}; only the first element of
//             each int list is used.
//
// The output keeps the input's rank. Index 0 is copied from the input,
// index 1 comes from weight sizes[0], and index 2 is computed by
// calc_out_size() from input sizes[2] and weight sizes[2].
void resize_conv1d_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  const ValueRef out_ref = args.at(0).refs.at(0);
  const ValueRef in_ref = args.at(1).refs.at(0);
  TensorRefPtr weight_tref = graph->get_tref(extra_args.at(0));

  const int64_t stride = graph->get_int_list(extra_args.at(1))->at(0);
  const int64_t padding = graph->get_int_list(extra_args.at(2))->at(0);
  const int64_t dilation = graph->get_int_list(extra_args.at(3))->at(0);

  const std::vector<int64_t>& w_sizes = weight_tref->sizes;
  const std::vector<int64_t> input_sizes = graph->sizes_of(in_ref);

  std::vector<int64_t> resized(input_sizes.size());

  const int64_t kernel_len = w_sizes.at(2);
  const int64_t input_len = input_sizes.at(2);

  resized.at(0) = input_sizes.at(0);
  resized.at(1) = w_sizes.at(0);
  resized.at(2) = calc_out_size(
      input_len, kernel_len, stride, padding, dilation, false);

  graph->virtual_resize(out_ref, resized);
}
// Creates a 1-D bias tensor and schedules a prepack node that fills it.
//
// vref:          value ref handed to the PrepackNode as its source.
// weight:        conv weight; supplies the dtype and the bias length
//                (sizes[1] when transposed, sizes[0] otherwise).
// storage_type / memory_layout: storage configuration of the new tensor.
//
// Returns the ValueRef of the newly added tensor.
ValueRef prepack_biases(
    ComputeGraph& graph,
    const ValueRef vref,
    const ValueRef weight,
    const bool transposed,
    const utils::StorageType storage_type,
    const utils::GPUMemoryLayout memory_layout) {
  const auto weight_dims = graph.sizes_of(weight);

  int64_t bias_len;
  if (transposed) {
    bias_len = weight_dims.at(1);
  } else {
    bias_len = weight_dims.at(0);
  }

  const ValueRef packed_bias = graph.add_tensor(
      {bias_len}, graph.dtype_of(weight), storage_type, memory_layout);

  vkapi::ShaderInfo transfer_shader = get_nchw_to_tensor_shader(
      graph, packed_bias, graph.get_staging_dtype_for(weight));

  graph.prepack_nodes().emplace_back(new PrepackNode(
      graph,
      transfer_shader,
      graph.create_global_wg_size(packed_bias),
      graph.create_local_wg_size(packed_bias),
      vref,
      packed_bias,
      {},
      // Specialization constants
      {graph.hashed_layout_of(packed_bias)},
      {graph.sizes_pc_of(packed_bias)}));

  return packed_bias;
}
// Builds the conv2d kernel name and looks up the matching shader.
//
// Name layout: <base>[_prepack_weights | _clamp]<out dtype suffix>
//              [<weight staging dtype suffix>, prepack only]
//
// Base names:
//   Depthwise                     -> "conv2d_dw"
//   Transposed                    -> "conv_transpose2d"
//   Pointwise (compute shader)    -> "conv2d_pw_s1p0" when stride_1_padding_0,
//                                    "conv2d_pw" otherwise
//   Pointwise (prepack shader)    -> "conv2d"
//   SlidingWindow                 -> "conv2d"
//
// clamp_out is ignored when prepack_weights is set. stride_equals_dilation is
// accepted but does not influence the name.
vkapi::ShaderInfo get_conv2d_shader(
    ComputeGraph& graph,
    const ValueRef out,
    const bool prepack_weights,
    const Conv2dMethod method,
    const ValueRef weight,
    const bool clamp_out = false,
    const bool stride_equals_dilation = false,
    const bool stride_1_padding_0 = false) {
  std::string kernel_name;
  kernel_name.reserve(kShaderNameReserve);

  if (method == Conv2dMethod::Depthwise) {
    kernel_name = "conv2d_dw";
  } else if (method == Conv2dMethod::Transposed) {
    kernel_name = "conv_transpose2d";
  } else if (method == Conv2dMethod::Pointwise && !prepack_weights) {
    kernel_name = stride_1_padding_0 ? "conv2d_pw_s1p0" : "conv2d_pw";
  } else if (
      method == Conv2dMethod::Pointwise ||
      method == Conv2dMethod::SlidingWindow) {
    kernel_name = "conv2d";
  }

  if (!prepack_weights) {
    if (clamp_out) {
      kernel_name += "_clamp";
    }
    add_dtype_suffix(kernel_name, graph.dtype_of(out));
  } else {
    kernel_name += "_prepack_weights";
    add_dtype_suffix(kernel_name, graph.dtype_of(out));
    // Prepack shaders are additionally specialized on the staging dtype.
    add_dtype_suffix(kernel_name, graph.get_staging_dtype_for(weight));
  }

  return VK_KERNEL_FROM_STR(kernel_name);
}
// Computes the 3-element size list of the packed weight tensor for the given
// convolution method.
//
// original_sizes: sizes of the unpacked weight, read from the back via
//                 utils::val_at(-4 .. -1). The values at -4 and -3 are rounded
//                 up with utils::align_up_4.
// method:         selects the packed layout.
//
// Every path returns a value: the Transposed layout is produced after the
// switch, so control can no longer flow off the end of this non-void function
// (which was undefined behaviour for an unexpected enum value), while the
// switch itself stays exhaustive for compiler -Wswitch diagnostics.
std::vector<int64_t> get_final_sizes(
    const std::vector<int64_t>& original_sizes,
    const Conv2dMethod method) {
  const int64_t batch_padded =
      utils::align_up_4(utils::val_at(-4, original_sizes));
  const int64_t channels_padded =
      utils::align_up_4(utils::val_at(-3, original_sizes));
  const int64_t height = utils::val_at(-2, original_sizes);
  const int64_t width = utils::val_at(-1, original_sizes);

  switch (method) {
    case Conv2dMethod::Depthwise:
      return std::vector<int64_t>{4, batch_padded / 4, height * width};
    case Conv2dMethod::Pointwise:
    case Conv2dMethod::SlidingWindow:
      return std::vector<int64_t>{
          4, batch_padded * height / 4, channels_padded * width};
    case Conv2dMethod::Transposed:
      break;
  }

  // Conv2dMethod::Transposed
  return std::vector<int64_t>{
      4, channels_padded * height / 4, batch_padded * width};
}
// Creates the packed conv2d weight tensor and schedules the prepack node that
// fills it.
//
// vref:   value ref of the unpacked weight; handed to the PrepackNode as its
//         source and used for dtype / size queries.
// method: convolution method; selects both the packed sizes
//         (get_final_sizes) and the prepack shader (get_conv2d_shader).
//
// The packed tensor is always a channels-packed 2-D texture. Returns its
// ValueRef.
ValueRef prepack_weights(
    ComputeGraph& graph,
    const ValueRef vref,
    const Conv2dMethod method) {
  const auto unpacked_sizes = graph.sizes_of(vref);
  const auto packed_sizes = get_final_sizes(unpacked_sizes, method);

  const ValueRef packed_weight = graph.add_tensor(
      packed_sizes,
      graph.dtype_of(vref),
      utils::kTexture2D,
      utils::kChannelsPacked);

  vkapi::ShaderInfo prepack_shader = get_conv2d_shader(
      graph, packed_weight, /*prepack_weights = */ true, method, vref);

  // The shader also receives the unpacked sizes, in reversed order.
  const auto unpacked_sizes_pc =
      utils::make_ivec4(unpacked_sizes, /*reverse = */ true);

  graph.prepack_nodes().emplace_back(new PrepackNode(
      graph,
      prepack_shader,
      graph.create_global_wg_size(packed_weight),
      graph.create_local_wg_size(packed_weight),
      vref,
      packed_weight,
      {},
      // Specialization constants
      {graph.packed_dim_of(packed_weight)},
      {graph.sizes_pc_of(packed_weight),
       PushConstantDataInfo(&unpacked_sizes_pc, sizeof(unpacked_sizes_pc))}));

  return packed_weight;
}
// Validates the tensor layouts required by the convolution shaders in this
// file: both the input and the output must be packed along the channels
// dimension. A failed VK_CHECK_COND reports the error (exact failure
// behaviour is defined by the macro, which lives outside this file).
void check_conv_args(
ComputeGraph& graph,
const ValueRef in,
const ValueRef out) {
VK_CHECK_COND(graph.packed_dim_of(in) == WHCN::kChannelsDim);
VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kChannelsDim);
}
// Extra conv2d parameters, filled in by create_conv2d_params() and uploaded
// with graph.create_params_buffer() in add_conv2d_node().
struct Conv2dParams final {
// Per-axis extent of the dilated kernel:
// kernel_size + (kernel_size - 1) * (dilation - 1).
utils::ivec2 overlay_region;
// Weight sizes[0] (transposed) or sizes[1] (otherwise), rounded up with
// utils::align_up_4.
int in_group_size;
};
// Output clamp bounds uploaded with graph.create_params_buffer(). Both fields
// stay at 0.0f when the corresponding value ref is kDummyValueRef (see
// add_conv2d_node / add_conv1d_node).
struct OutputParams final {
// Lower bound, extracted from the out_min scalar.
float out_min;
// Upper bound, extracted from the out_max scalar.
float out_max;
};
// Derives the Conv2dParams for a convolution.
//
// weight:     conv weight; sizes[0] (transposed) or sizes[1] (otherwise) is
//             rounded up with utils::align_up_4 to give in_group_size.
// p:          kernel parameters; kernel_size and dilation give the overlay
//             region kernel_size + (kernel_size - 1) * (dilation - 1) for
//             each of the two axes.
// transposed: whether this is a transposed convolution.
Conv2dParams create_conv2d_params(
    ComputeGraph& graph,
    const ValueRef weight,
    const Kernel2dParams& p,
    const bool transposed) {
  const auto dilated_extent_0 =
      p.kernel_size[0] + (p.kernel_size[0] - 1) * (p.dilation[0] - 1);
  const auto dilated_extent_1 =
      p.kernel_size[1] + (p.kernel_size[1] - 1) * (p.dilation[1] - 1);
  const auto overlay_region =
      utils::make_ivec2({dilated_extent_0, dilated_extent_1});

  const auto weight_dims = graph.sizes_of(weight);
  const auto group_channels =
      transposed ? weight_dims.at(0) : weight_dims.at(1);
  const int32_t in_group_size =
      utils::safe_downcast<int32_t>(utils::align_up_4(group_channels));

  return {overlay_region, in_group_size};
}
// Rejects kernel configurations the conv2d shaders cannot handle: a transposed
// convolution with dilation greater than 1 on either axis raises via VK_THROW.
// Non-transposed convolutions are always accepted.
void check_conv2d_params(const Kernel2dParams& p, const bool transposed) {
  if (!transposed) {
    return;
  }

  const bool has_dilation = p.dilation[0] > 1 || p.dilation[1] > 1;
  if (has_dilation) {
    VK_THROW(
        "aten.convolution.default: transposed = true, dilation > 1 is not supported yet!");
  }
}
// Classifies a 2-D convolution from its weight sizes, group count and
// transposed flag:
//   - transposed                                   -> Transposed
//   - weight sizes[0] == groups and sizes[1] == 1  -> Depthwise
//   - weight sizes[2] == 1 and sizes[3] == 1       -> Pointwise
//   - anything else                                -> SlidingWindow
Conv2dMethod get_conv2d_method(
    ComputeGraph& graph,
    const ValueRef weight,
    const int64_t groups,
    const bool transposed) {
  const auto dims = graph.sizes_of(weight);

  if (transposed) {
    return Conv2dMethod::Transposed;
  }

  const bool one_filter_per_group = dims.at(0) == groups && dims.at(1) == 1;
  if (one_filter_per_group) {
    return Conv2dMethod::Depthwise;
  }

  const bool kernel_is_1x1 = dims.at(2) == 1 && dims.at(3) == 1;
  return kernel_is_1x1 ? Conv2dMethod::Pointwise : Conv2dMethod::SlidingWindow;
}
// Returns the global workgroup size for a conv2d dispatch.
//
// Pointwise: the output's logical limits with the second component divided
// (rounding up) by 4. Every other method: the graph's default global
// workgroup size for `out`.
//
// weight_data and stride_equals_dilation are accepted but not used here.
utils::uvec3 create_conv2d_global_wg_size(
    ComputeGraph& graph,
    const Conv2dMethod method,
    const ValueRef out,
    const ValueRef weight_data,
    const bool stride_equals_dilation) {
  if (method != Conv2dMethod::Pointwise) {
    return graph.create_global_wg_size(out);
  }

  const utils::uvec3 out_limits = graph.logical_limits_of(out);
  return {
      utils::div_up(out_limits[0u], 1u),
      utils::div_up(out_limits[1u], 4u),
      out_limits[2u]};
}
// Global workgroup size callback for conv2d dispatch nodes.
//
// The convolution method is recovered from the shader name:
//   - name contains "conv2d_pw", or contains "conv2d" without
//     "conv_transpose2d": Pointwise when the weight (resize_args[0]) has
//     sizes[2] == sizes[3] == 1, SlidingWindow otherwise
//   - name contains "conv_transpose2d": Transposed
//   - otherwise: SlidingWindow
//
// For Pointwise the size from create_conv2d_global_wg_size() is flattened to
// {x * y, z, 1}, and x is further multiplied by 4 when the name contains
// "s1p0".
utils::uvec3 conv2d_global_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  const ValueRef out = args.at(0).refs.at(0);
  const ValueRef weight_data = resize_args.at(0);

  const auto& name = shader.kernel_name;
  const auto name_has = [&name](const char* needle) {
    return name.find(needle) != std::string::npos;
  };

  const bool transposed_name = name_has("conv_transpose2d");
  const bool conv2d_name =
      name_has("conv2d_pw") || (name_has("conv2d") && !transposed_name);

  Conv2dMethod method = Conv2dMethod::SlidingWindow;
  if (conv2d_name) {
    // Pointwise and sliding-window kernels are told apart by the kernel
    // extent stored in the weight sizes.
    const auto& w_sizes = graph->get_tref(weight_data)->sizes;
    if (w_sizes.at(2) == 1 && w_sizes.at(3) == 1) {
      method = Conv2dMethod::Pointwise;
    }
  } else if (transposed_name) {
    method = Conv2dMethod::Transposed;
  }

  // Derived from the shader name: true unless the name carries "_sned".
  const bool stride_equals_dilation = !name_has("_sned");

  utils::uvec3 wg_size = create_conv2d_global_wg_size(
      *graph, method, out, weight_data, stride_equals_dilation);

  if (method != Conv2dMethod::Pointwise) {
    return wg_size;
  }

  wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1};
  if (name_has("s1p0")) {
    wg_size[0] *= 4;
  }
  return wg_size;
}
// Local workgroup size callback for conv2d dispatch nodes.
//
// Shaders whose name contains "conv2d_pw", or contains "conv2d" without
// "conv_transpose2d", get a 64-invocation local size {64 / y, y, 1}, where y
// is the largest of 8, 4, 2 that divides global_workgroup_size[1] (1 if none
// does). All other shaders use the graph's default local size.
//
// NOTE(review): unlike conv2d_global_wg_size(), this does not look at the
// weight sizes, so every non-transposed "conv2d" shader takes the first
// branch - confirm this is intended.
utils::uvec3 conv2d_local_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
    const utils::uvec3& global_workgroup_size,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  (void)args;
  (void)resize_args;

  const auto& name = shader.kernel_name;
  const auto name_has = [&name](const char* needle) {
    return name.find(needle) != std::string::npos;
  };

  const bool use_pointwise_sizes = name_has("conv2d_pw") ||
      (name_has("conv2d") && !name_has("conv_transpose2d"));

  if (!use_pointwise_sizes) {
    return graph->create_local_wg_size(global_workgroup_size);
  }

  // Halve the candidate from 8 until it divides the global y extent; 1 always
  // divides, so the loop ends at 1 in the worst case.
  uint32_t local_y = 8;
  while (local_y > 1 && global_workgroup_size[1] % local_y != 0) {
    local_y /= 2;
  }

  return {64 / local_y, local_y, 1};
}
// Global workgroup size callback for conv1d dispatch nodes. Built purely from
// the output tensor's sizes:
//   x = size at dim -1 (out length)
//   y = size at dim -2 (out channels)
//   z = size at dim -3 (out batches), passed through utils::div_up_4
// The shader and resize args are not consulted.
utils::uvec3 conv1d_global_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  (void)shader;
  (void)resize_args;

  const ValueRef out = args.at(0).refs.at(0);

  const auto out_length = graph->size_at<uint32_t>(-1, out);
  const auto out_channels =
      static_cast<uint32_t>(graph->size_at<int64_t>(-2, out));
  const auto out_batch_groups =
      utils::div_up_4(graph->size_at<uint32_t>(-3, out));

  return {out_length, out_channels, out_batch_groups};
}
// Adds a 2-D convolution to the graph.
//
// in / out:        input and output tensors; both must be channels-packed
//                  (enforced by check_conv_args on the generic path).
// weight_data:     unpacked weight; its sizes select the convolution method.
// bias:            bias value ref.
// stride, padding, dilation, output_padding: kernel parameter value refs.
// transposed:      bool value ref.
// groups:          int value ref.
// out_min/out_max: clamp bounds; kDummyValueRef leaves the bound at 0.0f.
// clamp_out:       selects the "_clamp" shader variant.
//
// Pointwise and Depthwise convolutions are delegated to conv2d_pw_impl /
// conv2d_dw_impl (declared outside this file). SlidingWindow and Transposed
// convolutions are handled here: weights and bias are prepacked and a
// DynamicDispatchNode is appended to the graph's execute nodes.
//
// Raises via VK_THROW when the generic path sees an input whose sizes[0] is
// greater than 1, and via check_conv2d_params for transposed convolutions
// with dilation > 1.
void add_conv2d_node(
ComputeGraph& graph,
const ValueRef in,
const ValueRef weight_data,
const ValueRef bias,
const ValueRef stride,
const ValueRef padding,
const ValueRef dilation,
const ValueRef transposed,
const ValueRef output_padding,
const ValueRef groups,
const ValueRef out_min,
const ValueRef out_max,
const ValueRef out,
const bool clamp_out) {
const bool transposed_val = graph.get_bool(transposed);
// Clamp bounds default to 0 when no bound value was supplied.
float out_min_val = 0.0f;
float out_max_val = 0.0f;
if (out_min != kDummyValueRef) {
out_min_val = graph.extract_scalar<float>(out_min);
}
if (out_max != kDummyValueRef) {
out_max_val = graph.extract_scalar<float>(out_max);
}
const int64_t groups_val = graph.get_int(groups);
const Conv2dMethod method =
get_conv2d_method(graph, weight_data, groups_val, transposed_val);
// Use tiled path for all pointwise conv2d
// (dilation, groups and output_padding are not forwarded on this path).
if (method == Conv2dMethod::Pointwise) {
return conv2d_pw_impl(
graph,
in,
weight_data,
bias,
stride,
padding,
out,
transposed_val,
clamp_out,
out_min_val,
out_max_val);
}
// Depthwise convolutions have their own implementation as well.
if (method == Conv2dMethod::Depthwise) {
return conv2d_dw_impl(
graph,
in,
weight_data,
bias,
stride,
padding,
dilation,
out,
clamp_out,
out_min_val,
out_max_val);
}
// Generic path: only SlidingWindow and Transposed reach this point.
// Note that the prepack nodes are added before the batch-size check below.
ValueRef arg_weight = prepack_weights(graph, weight_data, method);
ValueRef arg_bias = prepack_biases(
graph,
bias,
weight_data,
transposed_val,
/* storage_type = */ utils::kTexture2D,
/* memory_layout = */ utils::kWidthPacked);
const std::vector<int64_t> in_sizes = graph.sizes_of(in);
if (in_sizes.at(0) > 1) {
VK_THROW("conv2d: input batch size > 1 is not supported yet!");
}
check_conv_args(graph, in, out);
Kernel2dParams kernel_params = create_kernel2d_params(
graph,
weight_data,
/*kernel_size_only = */ false,
stride,
padding,
dilation);
Conv2dParams extra_params =
create_conv2d_params(graph, weight_data, kernel_params, transposed_val);
// Both flags are forwarded to get_conv2d_shader; for the methods that reach
// this point that function does not use them when building the kernel name.
const bool stride_equals_dilation =
(kernel_params.stride[0] == kernel_params.dilation[0] &&
kernel_params.stride[1] == kernel_params.dilation[1]);
const bool stride_1_padding_0 =
(kernel_params.stride[0] == 1 && kernel_params.stride[1] == 1 &&
kernel_params.padding[0] == 0 && kernel_params.padding[1] == 0);
OutputParams out_params = {out_min_val, out_max_val};
check_conv2d_params(kernel_params, transposed_val);
vkapi::ShaderInfo shader = get_conv2d_shader(
graph,
out,
/*prepack_weights = */ false,
method,
weight_data,
clamp_out,
stride_equals_dilation,
stride_1_padding_0);
vkapi::ParamsBindList param_buffers = {
graph.logical_limits_ubo(out),
graph.sizes_ubo(in),
graph.create_params_buffer(kernel_params),
graph.create_params_buffer(extra_params),
graph.create_params_buffer(out_params),
};
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
shader,
conv2d_global_wg_size,
conv2d_local_wg_size,
// Inputs and Outputs
{{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}},
// Shader params buffers
param_buffers,
// Push Constants
{},
// Specialization Constants
{utils::safe_downcast<int32_t>(groups_val)},
// Resize Args
// (resize_conv2d_node reads these by index, so the order matters)
{weight_data, stride, padding, dilation, transposed, output_padding},
// Resizing Logic
resize_conv2d_node));
}
// Adds a channels-packed 1-D convolution to the graph.
//
// in / out:        input and output tensors; both must be channels-packed
//                  (enforced by check_conv_args).
// weight, bias:    unpacked weight and bias; both are prepacked here.
// stride, padding, dilation: int-list value refs; only element 0 is used.
// groups:          int value ref.
// out_min/out_max: clamp bounds; kDummyValueRef leaves the bound at 0.0f.
// clamp_out:       selects the "conv1d_clamp" shader variant.
//
// All int64 sizes are converted to the shader's int32 parameters through
// utils::safe_downcast, so an out-of-range value is reported by that helper
// instead of being silently truncated by an implicit narrowing conversion.
void add_conv1d_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef weight,
    const ValueRef bias,
    const ValueRef stride,
    const ValueRef padding,
    const ValueRef dilation,
    const ValueRef groups,
    const ValueRef out_min,
    const ValueRef out_max,
    const ValueRef out,
    const bool clamp_out) {
  ValueRef arg_weight = prepack_standard(
      graph,
      weight,
      graph.storage_type_of(out),
      utils::kChannelsPacked,
      /* passthrough = */ false,
      utils::kOptimizedAxisMap);
  ValueRef arg_bias = prepack_biases(
      graph,
      bias,
      weight,
      /*transposed = */ false,
      /*storage_type = */ utils::kTexture3D,
      /*memory_layout = */ utils::kWidthPacked);

  // Clamp bounds default to 0 when no bound value was supplied.
  float out_min_val = 0.0f;
  float out_max_val = 0.0f;
  if (out_min != kDummyValueRef) {
    out_min_val = graph.extract_scalar<float>(out_min);
  }
  if (out_max != kDummyValueRef) {
    out_max_val = graph.extract_scalar<float>(out_max);
  }

  const int64_t groups_val = graph.get_int(groups);

  const std::vector<int64_t> in_sizes = graph.sizes_of(in);
  const std::vector<int64_t> weight_sizes = graph.sizes_of(arg_weight);

  check_conv_args(graph, in, out);

  const int32_t in_channels = utils::safe_downcast<int32_t>(in_sizes.at(1));
  const int32_t out_channels =
      utils::safe_downcast<int32_t>(weight_sizes.at(0));
  const int32_t kernel_size = utils::safe_downcast<int32_t>(weight_sizes.at(2));
  const int32_t stride_size =
      utils::safe_downcast<int32_t>(graph.get_int_list(stride)->at(0));
  const int32_t padding_size =
      utils::safe_downcast<int32_t>(graph.get_int_list(padding)->at(0));
  const int32_t dilation_size =
      utils::safe_downcast<int32_t>(graph.get_int_list(dilation)->at(0));
  // Channels handled per group on the input and output side.
  const int32_t in_group_size =
      utils::safe_downcast<int32_t>(in_channels / groups_val);
  const int32_t out_group_size =
      utils::safe_downcast<int32_t>(out_channels / groups_val);

  Kernel1dParams kernel_params = {
      kernel_size,
      stride_size,
      padding_size,
      dilation_size,
      in_group_size,
      out_group_size};

  const OutputParams out_params = {out_min_val, out_max_val};

  // Reserve before building the name so the appends below do not reallocate.
  std::string kernel_name;
  kernel_name.reserve(kShaderNameReserve);
  kernel_name.append("conv1d");
  if (clamp_out) {
    kernel_name += "_clamp";
  }
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      conv1d_global_wg_size,
      default_pick_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}},
      // Shader params buffers
      {
          graph.logical_limits_ubo(out),
          graph.sizes_ubo(in),
          graph.create_params_buffer(kernel_params),
          graph.create_params_buffer(out_params),
      },
      // Push Constants
      {},
      // Specialization Constants
      {graph.hashed_layout_of(out),
       graph.hashed_layout_of(in),
       graph.hashed_layout_of(arg_weight),
       graph.hashed_layout_of(arg_bias)},
      // Resize Args (resize_conv1d_node reads these by index)
      {weight, stride, padding, dilation},
      // Resizing Logic
      resize_conv1d_node));
}
// Operator entry point shared by aten.convolution.default and the
// conv_with_clamp variants.
//
// Argument layout (as consumed below):
//   args[0..8]: in, weight, bias, stride, padding, dilation, transposed,
//               output_padding, groups
//   10 args:    args[9] = out (no clamping)
//   12+ args:   args[9] = out_min, args[10] = out_max, args[11] = out
//
// Dispatch:
//   - 4-D input                      -> add_conv2d_node
//   - other input, height-packed     -> et_vk.conv1d_pw / et_vk.conv1d_dw
//   - other input, otherwise         -> add_conv1d_node (channels-packed)
//
// Raises via VK_THROW when the argument count matches neither layout (the
// indices below would otherwise read past the end of `args`), or when a
// height-packed conv1d is neither pointwise nor depthwise.
void conv(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const size_t num_args = args.size();
  if (num_args != 10 && num_args < 12) {
    VK_THROW(
        "conv: expected 10 arguments (convolution) or 12 arguments "
        "(conv_with_clamp)");
  }

  // Resolve the clamp-dependent arguments once for every path below.
  const bool clamp_out = num_args != 10;
  const ValueRef out_min = clamp_out ? args[9] : kDummyValueRef;
  const ValueRef out_max = clamp_out ? args[10] : kDummyValueRef;
  const ValueRef out = clamp_out ? args[11] : args[9];

  const int64_t in_ndim = graph.dim_of(args[0]);
  if (in_ndim == 4) {
    // conv2d, with or without clamp
    add_conv2d_node(
        graph,
        args[0],
        args[1],
        args[2],
        args[3],
        args[4],
        args[5],
        args[6],
        args[7],
        args[8],
        out_min,
        out_max,
        out,
        clamp_out);
    return;
  }

  // Conv1d path
  if (graph.packed_dim_of(args[0]) == WHCN::kHeightDim) {
    // Height-packed: route to optimized conv1d implementations
    const auto weight_sizes = graph.sizes_of(args[1]);
    const int64_t groups_val = graph.get_int(args[8]);

    const bool is_pointwise = weight_sizes.at(2) == 1;
    const bool is_depthwise =
        groups_val == weight_sizes.at(0) && weight_sizes.at(1) == 1;

    // Unified 10-arg vector expected by the conv1d operators:
    //   in, weight, bias, stride, padding, dilation, groups,
    //   output_min, output_max, out
    std::vector<ValueRef> conv1d_args = {
        args[0],
        args[1],
        args[2],
        args[3],
        args[4],
        args[5],
        args[8],
        out_min,
        out_max,
        out};

    if (is_pointwise) {
      VK_GET_OP_FN("et_vk.conv1d_pw.default")(graph, conv1d_args);
    } else if (is_depthwise) {
      VK_GET_OP_FN("et_vk.conv1d_dw.default")(graph, conv1d_args);
    } else {
      VK_THROW(
          "Height-packed conv1d only supports pointwise (K=1) or "
          "depthwise (groups=C)");
    }
    return;
  }

  // Channels-packed fallback, with or without clamp
  add_conv1d_node(
      graph,
      args[0],
      args[1],
      args[2],
      args[3],
      args[4],
      args[5],
      args[8],
      out_min,
      out_max,
      out,
      clamp_out);
}
// Operator registration. All three operator names are bound to the same
// entry point, conv(), which tells the plain and clamped variants apart by
// their argument count.
REGISTER_OPERATORS {
VK_REGISTER_OP(aten.convolution.default, conv);
VK_REGISTER_OP(conv_with_clamp.default, conv);
VK_REGISTER_OP(et_vk.conv_with_clamp.default, conv);
}
} // namespace vkcompute