@@ -23,6 +23,17 @@ namespace subgroup2
2323namespace impl
2424{
2525
// Forward declarations.
// The generic inclusive_scan is written in terms of exclusive_scan and the generic
// exclusive_scan in terms of inclusive_scan; the generic templates also name their
// own ItemsPerInvocation == 1 instantiations. Declaring all three primary templates
// up front lets those aliases be spelled unqualified regardless of definition order.
template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
struct inclusive_scan;

template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
struct exclusive_scan;

template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
struct reduction;
35+
36+
2637// BinOp needed to specialize native
2738template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
2839struct inclusive_scan
@@ -31,7 +42,7 @@ struct inclusive_scan
3142 using scalar_t = typename Params::scalar_t;
3243 using binop_t = typename Params::binop_t;
3344 // assert binop_t == BinOp
34- using exclusive_scan_op_t = subgroup::impl:: exclusive_scan<binop_t, native>;
45+ using exclusive_scan_op_t = exclusive_scan<Params, binop_t, 1 , native>;
3546
3647 // NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
3748
@@ -43,7 +54,7 @@ struct inclusive_scan
4354 [unroll]
4455 for (uint32_t i = 1 ; i < ItemsPerInvocation; i++)
4556 retval[i] = binop (retval[i-1 ], value[i]);
46-
57+
4758 exclusive_scan_op_t op;
4859 scalar_t exclusive = op (retval[ItemsPerInvocation-1 ]);
4960
@@ -60,7 +71,7 @@ struct exclusive_scan
6071 using type_t = typename Params::type_t;
6172 using scalar_t = typename Params::scalar_t;
6273 using binop_t = typename Params::binop_t;
63- using inclusive_scan_op_t = subgroup2::impl:: inclusive_scan<Params, binop_t, ItemsPerInvocation, native>;
74+ using inclusive_scan_op_t = inclusive_scan<Params, binop_t, ItemsPerInvocation, native>;
6475
6576 // NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
6677
@@ -86,7 +97,7 @@ struct reduction
8697 using type_t = typename Params::type_t;
8798 using scalar_t = typename Params::scalar_t;
8899 using binop_t = typename Params::binop_t;
89- using op_t = subgroup::impl:: reduction<binop_t, native>;
100+ using op_t = reduction<Params, binop_t, 1 , native>;
90101
91102 // NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
92103
@@ -142,25 +153,25 @@ struct inclusive_scan<Params, BinOp, 1, false>
142153 // affected by https://github.com/microsoft/DirectXShaderCompiler/issues/7006
143154 // NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
144155
// Inclusive scan of one scalar per invocation (emulated, non-native path).
// Takes and returns scalar_t rather than type_t: this is the ItemsPerInvocation == 1
// specialization, and the generic multi-item scan feeds it a single scalar.
// Thin forwarding wrapper so the functor form and the static form share one implementation.
scalar_t operator()(scalar_t value)
{
    return __call(value);
}
149160
// Emulated inclusive scan across the subgroup using shuffles.
//
// Each pass combines an invocation's running value with the running value held
// `step` invocations below it, with `step` doubling every pass (1, 2, 4, ...).
// `step` must be a power of two: a linear step (i * 2 -> 2, 4, 6, ...) would skip
// and double-count elements.
//
// value:   this invocation's input.
// returns: binop applied over the inputs of invocations [0, gl_SubgroupInvocationID].
static scalar_t __call(scalar_t value)
{
    binop_t op;
    const uint32_t subgroupInvocation = glsl::gl_SubgroupInvocationID();

    // First pass (distance 1) is peeled out of the loop.
    // All invocations must execute the shuffle, even if we don't apply the op() to all of them.
    scalar_t rhs = glsl::subgroupShuffleUp<scalar_t>(value, 1u);
    // Invocations with index < step have no source `step` below them, so they combine with the identity.
    value = op(value, hlsl::mix(rhs, binop_t::identity, subgroupInvocation < 1u));

    // Loop bound goes through integral_constant — presumably part of the workaround for
    // DXC issue #7006 referenced on the enclosing struct; confirm before simplifying.
    const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
    [unroll]
    for (uint32_t i = 1; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)
    {
        const uint32_t step = 1u << i;
        rhs = glsl::subgroupShuffleUp<scalar_t>(value, step);
        value = op(value, hlsl::mix(rhs, binop_t::identity, subgroupInvocation < step));
    }
    return value;
}
@@ -174,13 +185,13 @@ struct exclusive_scan<Params, BinOp, 1, false>
174185 using scalar_t = typename Params::scalar_t;
175186 using binop_t = typename Params::binop_t;
176187
// Exclusive scan of one scalar per invocation (emulated, non-native path).
// Computed as the inclusive scan shifted up by one invocation, with the binop's
// identity placed in the first invocation.
//
// `value` is taken by value (not by const reference) because it is overwritten
// with the inclusive-scan result below.
scalar_t operator()(scalar_t value)
{
    value = inclusive_scan<Params, BinOp, 1, false>::__call(value);
    // The shuffle has to be executed by every invocation, so its result is stored in a
    // variable first instead of being written inline in the select below, where it
    // could be short-circuited away for the first invocation.
    scalar_t left = glsl::subgroupShuffleUp<scalar_t>(value, 1);
    // The first invocation has nothing to its left, so it gets the binop's identity value for the exclusive scan.
    return bool(glsl::gl_SubgroupInvocationID()) ? left : binop_t::identity;
}
185196};
186197
@@ -190,11 +201,21 @@ struct reduction<Params, BinOp, 1, false>
190201 using type_t = typename Params::type_t;
191202 using scalar_t = typename Params::scalar_t;
192203 using binop_t = typename Params::binop_t;
204+ using config_t = typename Params::config_t;
193205
// affected by https://github.com/microsoft/DirectXShaderCompiler/issues/7006
// NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupSizeLog2 = config_t::SizeLog2;

// Subgroup-wide reduction of one scalar per invocation (emulated, non-native path).
//
// Pass i combines this invocation's running value with that of the invocation whose
// index differs in bit i (XOR shuffle), for i in [0, SizeLog2). The running value is
// returned directly, with no trailing broadcast of the last invocation's value.
//
// NOTE(review): the two partners of a pass apply op() with their operands in opposite
// order, so this looks like it relies on binop_t being commutative — confirm.
// NOTE(review): presumably requires the subgroup to hold exactly 2^SizeLog2 active
// invocations so every XOR partner exists — verify against callers.
scalar_t operator()(scalar_t value)
{
    binop_t op;

    // Bound is routed through integral_constant because of the DXC issue linked above.
    const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
    [unroll]
    for (uint32_t i = 0; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)
        value = op(glsl::subgroupShuffleXor<scalar_t>(value, 0x1u << i), value);

    return value;
}
199220};
200221
0 commit comments