44#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
55#define _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
66
7- // #include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl"
8- // #include "nbl/builtin/hlsl/glsl_compat/subgroup_arithmetic.hlsl"
7+ #include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl"
8+ #include "nbl/builtin/hlsl/glsl_compat/subgroup_arithmetic.hlsl"
99
10- // #include "nbl/builtin/hlsl/subgroup/ballot.hlsl"
10+ #include "nbl/builtin/hlsl/subgroup/ballot.hlsl"
11+ #include "nbl/builtin/hlsl/subgroup2/ballot.hlsl"
1112
12- // #include "nbl/builtin/hlsl/functional.hlsl"
13+ #include "nbl/builtin/hlsl/functional.hlsl"
1314
14- #include "nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl"
15+ // #include "nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl"
1516
1617namespace nbl
1718{
@@ -23,12 +24,14 @@ namespace subgroup2
2324namespace impl
2425{
2526
26- template<class Params, uint32_t ItemsPerInvocation, bool native>
27+ // BinOp needed to specialize native
28+ template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
2729struct inclusive_scan
2830{
2931 using type_t = typename Params::type_t;
3032 using scalar_t = typename Params::scalar_t;
3133 using binop_t = typename Params::binop_t;
34+ // assert binop_t == BinOp
3235 using exclusive_scan_op_t = subgroup::impl::exclusive_scan<binop_t, native>;
3336
3437 // NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
@@ -52,13 +55,13 @@ struct inclusive_scan
5255 }
5356};
5457
55- template<class Params, uint32_t ItemsPerInvocation, bool native>
58+ template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
5659struct exclusive_scan
5760{
5861 using type_t = typename Params::type_t;
5962 using scalar_t = typename Params::scalar_t;
6063 using binop_t = typename Params::binop_t;
61- using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<Params, ItemsPerInvocation, native>;
64+ using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<Params, binop_t, ItemsPerInvocation, native>;
6265
6366 // NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
6467
@@ -78,7 +81,7 @@ struct exclusive_scan
7881 }
7982};
8083
81- template<class Params, uint32_t ItemsPerInvocation, bool native>
84+ template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
8285struct reduction
8386{
8487 using type_t = typename Params::type_t;
@@ -103,74 +106,98 @@ struct reduction
103106
104107// specializations for ItemsPerInvocation=1 with native=true forward straight to the glsl::subgroup* functions
105108// specialize native: SPECIALIZE(NAME,BINOP,SUBGROUP_OP) defines NAME<Params,BINOP<T>,1,true> whose operator() calls glsl::subgroup##SUBGROUP_OP (in a #define the '(' must touch the macro name, otherwise the macro is object-like)
106- // #define SPECIALIZE(NAME,BINOP,SUBGROUP_OP) template<typename T> struct NAME<BINOP<T>,true> \
107- // { \
108- // using type_t = T; \
109- // \
110- // type_t operator()(NBL_CONST_REF_ARG(type_t) v) {return glsl::subgroup##SUBGROUP_OP<type_t>(v);} \
111- // }
109+ #define SPECIALIZE(NAME,BINOP,SUBGROUP_OP) template<class Params, typename T> struct NAME<Params, BINOP<T>, 1 ,true > \
110+ { \
111+ using type_t = T; \
112+ \
113+ type_t operator ()(NBL_CONST_REF_ARG (type_t) v) {return glsl::subgroup##SUBGROUP_OP<type_t>(v);} \
114+ }
112115 // SPECIALIZE_ALL(BINOP,SUBGROUP_OP) stamps out the reduction, inclusive_scan and exclusive_scan native specializations for one binary op
113- // #define SPECIALIZE_ALL(BINOP,SUBGROUP_OP) SPECIALIZE(reduction,BINOP,SUBGROUP_OP); \
114- // SPECIALIZE(inclusive_scan,BINOP,Inclusive##SUBGROUP_OP); \
115- // SPECIALIZE(exclusive_scan,BINOP,Exclusive##SUBGROUP_OP);
116+ #define SPECIALIZE_ALL(BINOP,SUBGROUP_OP) SPECIALIZE (reduction,BINOP,SUBGROUP_OP); \
117+ SPECIALIZE (inclusive_scan,BINOP,Inclusive##SUBGROUP_OP); \
118+ SPECIALIZE (exclusive_scan,BINOP,Exclusive##SUBGROUP_OP);
116119
117- // SPECIALIZE_ALL(bit_and,And);
118- // SPECIALIZE_ALL(bit_or,Or);
119- // SPECIALIZE_ALL(bit_xor,Xor);
120+ SPECIALIZE_ALL (bit_and,And); // reduction -> glsl::subgroupAnd, scans -> glsl::subgroupInclusiveAnd / glsl::subgroupExclusiveAnd
121+ SPECIALIZE_ALL (bit_or,Or); // same pattern with the Or suffix
122+ SPECIALIZE_ALL (bit_xor,Xor); // same pattern with the Xor suffix
120123
121- // SPECIALIZE_ALL(plus,Add);
122- // SPECIALIZE_ALL(multiplies,Mul);
124+ SPECIALIZE_ALL (plus,Add ); // plus<T> maps onto the subgroup*Add functions
125+ SPECIALIZE_ALL (multiplies,Mul); // multiplies<T> maps onto the subgroup*Mul functions
123126
124- // SPECIALIZE_ALL(minimum,Min);
125- // SPECIALIZE_ALL(maximum,Max);
127+ SPECIALIZE_ALL (minimum,Min ); // minimum<T> maps onto the subgroup*Min functions
128+ SPECIALIZE_ALL (maximum,Max ); // maximum<T> maps onto the subgroup*Max functions
126129 // the helper macros are only needed for the instantiations above, so drop them again
127- // #undef SPECIALIZE_ALL
128- // #undef SPECIALIZE
130+ #undef SPECIALIZE_ALL
131+ #undef SPECIALIZE
129132
130133// specialize portability (native == false): the scans are emulated with subgroup shuffles
131- template<class Params, bool native >
132- struct inclusive_scan<Params, 1 , native >
134+ template<class Params, class BinOp >
135+ struct inclusive_scan<Params, BinOp, 1 , false >
133136{
134137 using type_t = typename Params::type_t;
135138 using scalar_t = typename Params::scalar_t;
136139 using binop_t = typename Params::binop_t;
137- using op_t = subgroup::impl::inclusive_scan<binop_t, native>;
138140 // assert T == scalar type, binop::type == T
141+ using config_t = typename Params::config_t;
139142
140- type_t operator ()(NBL_CONST_REF_ARG (type_t) value)
143+ // affected by https://github.com/microsoft/DirectXShaderCompiler/issues/7006
144+ // NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
145+ // functor entry point; only forwards to the static __call so other specializations can reuse the scan without an instance
146+ type_t operator ()(type_t value)
141147 {
142- op_t op;
143- return op (value);
148+ return __call (value);
149+ }
150+ // portable inclusive scan: every pass combines `value` with the one `step` invocations below via subgroupShuffleUp, and `step` doubles each pass (1, 2, 4, ...)
151+ static type_t __call (type_t value)
152+ {
153+ binop_t op;
154+ const uint32_t subgroupInvocation = glsl::gl_SubgroupInvocationID ();
155+ // first pass (step 1) is peeled out of the loop
156+ type_t rhs = glsl::subgroupShuffleUp<type_t>(value, 1u); // all invocations must execute the shuffle, even if we don't apply the op() to all of them
157+ // TODO waiting on mix intrinsic fix from bxdf branch, value = op(value, hlsl::mix(rhs, binop_t::identity, subgroupInvocation < 1u));
158+ value = op (value, subgroupInvocation<1u ? binop_t::identity : rhs);
159+ // remaining SizeLog2-1 passes; invocations with an index below `step` combine with the identity instead of the shuffled value
160+ const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
161+ [unroll]
162+ for (uint32_t i = 1 ; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)
163+ {
164+ const uint32_t step = 1u << i; // stride has to double every pass (2, 4, 8, ...); `i * 2` yields 2, 4, 6, ... and breaks subgroups of 16 or more invocations
165+ rhs = glsl::subgroupShuffleUp<type_t>(value, step);
166+ // TODO value = op(value, hlsl::mix(rhs, binop_t::identity, subgroupInvocation < step));
167+ value = op (value, subgroupInvocation<step ? binop_t::identity : rhs);
168+ }
169+ return value;
144170 }
145171};
146172
147- template<class Params, bool native >
148- struct exclusive_scan<Params, 1 , native >
173+ template<class Params, class BinOp >
174+ struct exclusive_scan<Params, BinOp, 1 , false >
149175{
150176 using type_t = typename Params::type_t;
151177 using scalar_t = typename Params::scalar_t;
152178 using binop_t = typename Params::binop_t;
153- using op_t = subgroup::impl::exclusive_scan<binop_t, native>;
154179 // portable exclusive scan: the inclusive scan shifted up by one invocation, with the binop identity in invocation 0
155180 type_t operator ()(NBL_CONST_REF_ARG (type_t) value)
156181 {
157- op_t op;
158- return op (value);
182+ const type_t scanned = inclusive_scan<Params, BinOp, 1 , false >::__call (value); // `value` is a const argument and must not be assigned to, so the scan result gets its own local
183+ // can't risk getting short-circuited, need to store to a var
184+ type_t left = glsl::subgroupShuffleUp<type_t>(scanned, 1u);
185+ // the first invocation has nothing to its left, so it gets the binop's identity value for the exclusive scan
186+ return bool (glsl::gl_SubgroupInvocationID ()) ? left:binop_t::identity;
159187 }
160188};
161189
162- template<class Params, bool native >
163- struct reduction<Params, 1 , native >
190+ template<class Params, class BinOp >
191+ struct reduction<Params, BinOp, 1 , false >
164192{
165193 using type_t = typename Params::type_t;
166194 using scalar_t = typename Params::scalar_t;
167195 using binop_t = typename Params::binop_t;
168- using op_t = subgroup::impl::reduction<binop_t, native>;
169196
170197 scalar_t operator ()(NBL_CONST_REF_ARG (type_t) value)
171198 {
172- op_t op;
173- return op (value);
199+ const type_t scanned = inclusive_scan<Params, BinOp, 1 , false >::__call (value); // run the portable inclusive scan first
200+ return subgroup::BroadcastLast<type_t>(scanned); // the reduction is the last subgroup invocation's scan value
174201 }
175202};
176203
0 commit comments