-
Notifications
You must be signed in to change notification settings - Fork 267
[CK_BUILDER] Add reflection for wmma and bwd weight instances to ck builder reflection #3592
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
kabrahamAMD
wants to merge
12
commits into
develop
Choose a base branch
from
kabraham/builder_bwd_reflection
base: develop
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
2f13c7a
added reflection for conv_fwd_multiple_d_wmma_cshuffle.hpp
kabraham-streamhpc e945dbb
added reflection for device_grouped_conv_bwd_weight_xdl_cshuffle
kabraham-streamhpc ea1e3fe
added reflection for device_grouped_conv_bwd_weight_xdl_cshuffle v3
kabraham-streamhpc b9f197d
added reflection of max_transpose parameters
kabraham-streamhpc 8013a02
fix printing of std optional parameters
kabraham-streamhpc 7ef80d0
fix use of undefined ck::index
kabraham-streamhpc c5da43d
added conv traits for device_grouped_conv_bwd_weight_multiple_d_xdl_c…
kabraham-streamhpc 3e034ba
added xdl two stage instance to reflection
kabraham-streamhpc 9907357
added additional variables
kabraham-streamhpc c84973a
added reflection for grouped_conv_bwd_weight_multiple_d_wmma_cshuffle…
kabraham-streamhpc e8c333b
added reflection for device_grouped_conv_bwd_weigh_wmma_cshuffle_v3
kabraham-streamhpc 71df5c5
added reflection for bwd_weight_wmma_cshuffle
kabraham-streamhpc File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
46 changes: 46 additions & 0 deletions
46
...uilder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. | ||
| // SPDX-License-Identifier: MIT | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <concepts> | ||
|
|
||
| #include "ck_tile/builder/reflect/conv_traits.hpp" | ||
| #include "ck_tile/builder/reflect/conv_traits_helpers.hpp" | ||
| #include "ck_tile/builder/reflect/instance_traits.hpp" | ||
| #include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" | ||
|
|
||
| namespace ck_tile::reflect::conv { | ||
|
|
||
| /// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_Tag | ||
| template <typename Instance> | ||
| requires HasInstanceTraits<Instance> && | ||
| std::same_as<typename InstanceTraits<Instance>::device_kernel_tag, | ||
| DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag> | ||
| constexpr ConvTraits instance_to_conv_traits() | ||
| { | ||
| using InstTraits = InstanceTraits<Instance>; | ||
|
|
||
| return ConvTraits{ | ||
| .spatial_dim = InstTraits::kSpatialDim, | ||
| .direction = conv_direction<Instance>(), | ||
| .layout = bwd_wei_conv_layout<Instance>(), | ||
| .data_type = conv_data_type<typename InstTraits::InDataType>(), | ||
| .input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(), | ||
| .weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(), | ||
| .output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(), | ||
| .conv_specialization = conv_spec<Instance>(), | ||
| .thread_block_size = InstTraits::kBlockSize, | ||
| .tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kKPerBlock), | ||
| .a_tile_transfer = | ||
| conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock), | ||
| .b_tile_transfer = | ||
| conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock), | ||
| .warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(), | ||
| .c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(), | ||
| .pipeline_version = get_pipeline_version<InstTraits>(), | ||
| .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(), | ||
| }; | ||
| } | ||
|
|
||
| } // namespace ck_tile::reflect::conv |
53 changes: 53 additions & 0 deletions
53
...le/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. | ||
| // SPDX-License-Identifier: MIT | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <concepts> | ||
|
|
||
| #include "ck_tile/builder/reflect/conv_traits.hpp" | ||
| #include "ck_tile/builder/reflect/conv_traits_helpers.hpp" | ||
| #include "ck_tile/builder/reflect/instance_traits.hpp" | ||
| #include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp" | ||
|
|
||
| namespace ck_tile::reflect::conv { | ||
|
|
||
| /// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag | ||
| template <typename Instance> | ||
| requires HasInstanceTraits<Instance> && | ||
| std::same_as<typename InstanceTraits<Instance>::device_kernel_tag, | ||
| DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag> | ||
| constexpr ConvTraits instance_to_conv_traits() | ||
| { | ||
| using InstTraits = InstanceTraits<Instance>; | ||
|
|
||
| return ConvTraits{ | ||
| .spatial_dim = InstTraits::kSpatialDim, | ||
| .direction = conv_direction<Instance>(), | ||
| .layout = bwd_wei_conv_layout<Instance>(), | ||
| .data_type = conv_data_type<typename InstTraits::InDataType>(), | ||
| .input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(), | ||
| .weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(), | ||
| .output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(), | ||
| .conv_specialization = conv_spec<Instance>(), | ||
| .thread_block_size = InstTraits::kBlockSize, | ||
| .tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock), | ||
| .a_tile_transfer = | ||
| conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock), | ||
| .b_tile_transfer = | ||
| conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock), | ||
| .warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(), | ||
| .c_tile_transfer = | ||
| {.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, | ||
| .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, | ||
| .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], | ||
| InstTraits::kCThreadClusterLengths[1], | ||
| InstTraits::kCThreadClusterLengths[2], | ||
| InstTraits::kCThreadClusterLengths[3]}, | ||
| .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl}, | ||
| .pipeline_version = get_pipeline_version<InstTraits>(), | ||
| .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(), | ||
| }; | ||
| } | ||
|
|
||
| } // namespace ck_tile::reflect::conv |
50 changes: 50 additions & 0 deletions
50
...builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. | ||
| // SPDX-License-Identifier: MIT | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <concepts> | ||
|
|
||
| #include "ck_tile/builder/reflect/conv_traits.hpp" | ||
| #include "ck_tile/builder/reflect/conv_traits_helpers.hpp" | ||
| #include "ck_tile/builder/reflect/instance_traits.hpp" | ||
| #include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp" | ||
|
|
||
| namespace ck_tile::reflect::conv { | ||
|
|
||
| /// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_wmma_CShuffle_Tag | ||
| template <typename Instance> | ||
| requires HasInstanceTraits<Instance> && | ||
| std::same_as<typename InstanceTraits<Instance>::device_kernel_tag, | ||
| DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag> | ||
| constexpr ConvTraits instance_to_conv_traits() | ||
| { | ||
| using InstTraits = InstanceTraits<Instance>; | ||
|
|
||
| return ConvTraits{ | ||
| .spatial_dim = InstTraits::kSpatialDim, | ||
| .direction = conv_direction<Instance>(), | ||
| .layout = bwd_wei_conv_layout<Instance>(), | ||
| .data_type = conv_data_type<typename InstTraits::InDataType>(), | ||
| .input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(), | ||
| .weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(), | ||
| .output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(), | ||
| .conv_specialization = conv_spec<Instance>(), | ||
| .thread_block_size = InstTraits::kBlockSize, | ||
| .tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kKPerBlock), | ||
| .a_tile_transfer = | ||
| conv_traits_a_transfer_params<InstTraits>(InstTraits::kABK1, InstTraits::kKPerBlock), | ||
| .b_tile_transfer = | ||
| conv_traits_b_transfer_params<InstTraits>(InstTraits::kABK1, InstTraits::kKPerBlock), | ||
| .warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(), | ||
| .c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(), | ||
| .pipeline_version = get_pipeline_version<InstTraits>(), | ||
| .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(), | ||
| .max_transpose_transfer_src_scalar_per_vector = | ||
| InstTraits::kTransposeTransferSrcScalarPerVector, | ||
| .max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector, | ||
| .num_groups_to_merge = InstTraits::kNumGroupsToMerge, | ||
| }; | ||
| } | ||
|
|
||
| } // namespace ck_tile::reflect::conv |
57 changes: 57 additions & 0 deletions
57
...ile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. | ||
| // SPDX-License-Identifier: MIT | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <concepts> | ||
|
|
||
| #include "ck_tile/builder/reflect/conv_traits.hpp" | ||
| #include "ck_tile/builder/reflect/conv_traits_helpers.hpp" | ||
| #include "ck_tile/builder/reflect/instance_traits.hpp" | ||
| #include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp" | ||
|
|
||
| namespace ck_tile::reflect::conv { | ||
|
|
||
| /// @brief Tag dispatch implementation for DeviceGroupedConvBwdTwoStage_Xdl_CShuffle_Tag | ||
| template <typename Instance> | ||
| requires HasInstanceTraits<Instance> && | ||
| std::same_as<typename InstanceTraits<Instance>::device_kernel_tag, | ||
| DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffle_Tag> | ||
| constexpr ConvTraits instance_to_conv_traits() | ||
| { | ||
| using InstTraits = InstanceTraits<Instance>; | ||
|
|
||
| return ConvTraits{ | ||
| .spatial_dim = InstTraits::kSpatialDim, | ||
| .direction = conv_direction<Instance>(), | ||
| .layout = bwd_wei_conv_layout<Instance>(), | ||
| .data_type = conv_data_type<typename InstTraits::InDataType>(), | ||
| .input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(), | ||
| .weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(), | ||
| .output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(), | ||
| .conv_specialization = conv_spec<Instance>(), | ||
| .thread_block_size = InstTraits::kBlockSize, | ||
| .tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kKPerBlock), | ||
| .a_tile_transfer = | ||
| conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock), | ||
| .b_tile_transfer = | ||
| conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock), | ||
| .warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(), | ||
| .c_tile_transfer = | ||
| {.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, | ||
| .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, | ||
| .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], | ||
| InstTraits::kCThreadClusterLengths[1], | ||
| InstTraits::kCThreadClusterLengths[2], | ||
| InstTraits::kCThreadClusterLengths[3]}, | ||
| .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl}, | ||
| .pipeline_version = get_pipeline_version<InstTraits>(), | ||
| .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(), | ||
| .max_transpose_transfer_src_scalar_per_vector = | ||
| InstTraits::kTransposeTransferSrcScalarPerVector, | ||
| .max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector, | ||
| .num_groups_to_merge = InstTraits::kNumGroupsToMerge, | ||
| }; | ||
| } | ||
|
|
||
| } // namespace ck_tile::reflect::conv |
48 changes: 48 additions & 0 deletions
48
...lude/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. | ||
| // SPDX-License-Identifier: MIT | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <concepts> | ||
|
|
||
| #include "ck_tile/builder/reflect/conv_traits.hpp" | ||
| #include "ck_tile/builder/reflect/conv_traits_helpers.hpp" | ||
| #include "ck_tile/builder/reflect/instance_traits.hpp" | ||
| #include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp" | ||
|
|
||
| namespace ck_tile::reflect::conv { | ||
|
|
||
| /// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag | ||
| template <typename Instance> | ||
| requires HasInstanceTraits<Instance> && | ||
| std::same_as<typename InstanceTraits<Instance>::device_kernel_tag, | ||
| DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag> | ||
| constexpr ConvTraits instance_to_conv_traits() | ||
| { | ||
| using InstTraits = InstanceTraits<Instance>; | ||
|
|
||
| return ConvTraits{ | ||
| .spatial_dim = InstTraits::kSpatialDim, | ||
| .direction = conv_direction<Instance>(), | ||
| .layout = bwd_wei_conv_layout<Instance>(), | ||
| .data_type = conv_data_type<typename InstTraits::InDataType>(), | ||
| .input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(), | ||
| .weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(), | ||
| .output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(), | ||
| .conv_specialization = conv_spec<Instance>(), | ||
| .thread_block_size = InstTraits::kBlockSize, | ||
| .tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock), | ||
| .a_tile_transfer = | ||
| conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock), | ||
| .b_tile_transfer = | ||
| conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock), | ||
| .warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(), | ||
| .c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(), | ||
| .num_gemm_k_prefetch_stage = InstTraits::kNumGemmKPrefetchStage, | ||
| .pipeline_version = get_pipeline_version<InstTraits>(), | ||
| .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(), | ||
|
|
||
| }; | ||
| } | ||
|
|
||
| } // namespace ck_tile::reflect::conv |
50 changes: 50 additions & 0 deletions
50
...e/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. | ||
| // SPDX-License-Identifier: MIT | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <concepts> | ||
|
|
||
| #include "ck_tile/builder/reflect/conv_traits.hpp" | ||
| #include "ck_tile/builder/reflect/conv_traits_helpers.hpp" | ||
| #include "ck_tile/builder/reflect/instance_traits.hpp" | ||
| #include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp" | ||
|
|
||
| namespace ck_tile::reflect::conv { | ||
|
|
||
| /// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag | ||
| template <typename Instance> | ||
| requires HasInstanceTraits<Instance> && | ||
| std::same_as<typename InstanceTraits<Instance>::device_kernel_tag, | ||
| DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3_Tag> | ||
| constexpr ConvTraits instance_to_conv_traits() | ||
| { | ||
| using InstTraits = InstanceTraits<Instance>; | ||
|
|
||
| return ConvTraits{ | ||
| .spatial_dim = InstTraits::kSpatialDim, | ||
| .direction = conv_direction<Instance>(), | ||
| .layout = bwd_wei_conv_layout<Instance>(), | ||
| .data_type = conv_data_type<typename InstTraits::InDataType>(), | ||
| .input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(), | ||
| .weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(), | ||
| .output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(), | ||
| .conv_specialization = conv_spec<Instance>(), | ||
| .thread_block_size = InstTraits::kBlockSize, | ||
| .tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kKPerBlock), | ||
| .a_tile_transfer = | ||
| conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock), | ||
| .b_tile_transfer = | ||
| conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock), | ||
| .warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(), | ||
| .c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(), | ||
| .pipeline_version = get_pipeline_version<InstTraits>(), | ||
| .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(), | ||
| .max_transpose_transfer_src_scalar_per_vector = | ||
| InstTraits::kMaxTransposeTransferSrcScalarPerVector, | ||
| .max_transpose_dst_scalar_per_vector = InstTraits::kMaxTransposeTransferDstScalarPerVector, | ||
|
|
||
| }; | ||
| } | ||
|
|
||
| } // namespace ck_tile::reflect::conv |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is really good, and you should lead your PR description with this change to ConvTraits (the "what" of the PR), as well as why we are making these optional now (the "why"). One question I have is where we should use std::optional versus using std::variant.
That's the design discussion we should focus on: how should ConvTraits be generalized for backward weights. This PR should update code comments and our
relect/README.mdfile so that everyone understands this important generalization.