1- // Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
1+ // Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
22// This file is part of the "Nabla Engine".
33// For conditions of distribution and use, see copyright notice in nabla.h
44#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
55#define _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
66
7+ // #include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl"
8+ // #include "nbl/builtin/hlsl/glsl_compat/subgroup_arithmetic.hlsl"
9+
10+ // #include "nbl/builtin/hlsl/subgroup/ballot.hlsl"
11+
12+ // #include "nbl/builtin/hlsl/functional.hlsl"
13+
714#include "nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl"
815
916namespace nbl
@@ -16,12 +23,12 @@ namespace subgroup2
1623namespace impl
1724{
1825
19- template<class Binop, typename T , uint32_t ItemsPerInvocation, bool native>
26+ template<class Params , uint32_t ItemsPerInvocation, bool native>
2027struct inclusive_scan
2128{
22- using type_t = T ;
23- using scalar_t = typename Binop::type_t ;
24- using binop_t = Binop ;
29+ using type_t = typename Params::type_t ;
30+ using scalar_t = typename Params::scalar_t ;
31+ using binop_t = typename Params::binop_t ;
2532 using exclusive_scan_op_t = subgroup::impl::exclusive_scan<binop_t, native>;
2633
2734 // NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
@@ -31,27 +38,27 @@ struct inclusive_scan
3138 binop_t binop;
3239 type_t retval;
3340 retval[0 ] = value[0 ];
34- // [unroll(ItemsPerInvocation-1) ]
41+ [unroll]
3542 for (uint32_t i = 1 ; i < ItemsPerInvocation; i++)
3643 retval[i] = binop (retval[i-1 ], value[i]);
3744
3845 exclusive_scan_op_t op;
3946 scalar_t exclusive = op (retval[ItemsPerInvocation-1 ]);
4047
41- // [unroll(ItemsPerInvocation) ]
48+ [unroll]
4249 for (uint32_t i = 0 ; i < ItemsPerInvocation; i++)
4350 retval[i] = binop (retval[i], exclusive);
4451 return retval;
4552 }
4653};
4754
48- template<class Binop, typename T , uint32_t ItemsPerInvocation, bool native>
55+ template<class Params , uint32_t ItemsPerInvocation, bool native>
4956struct exclusive_scan
5057{
51- using type_t = T ;
52- using scalar_t = typename Binop::type_t ;
53- using binop_t = Binop ;
54- using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<binop_t, T , ItemsPerInvocation, native>;
58+ using type_t = typename Params::type_t ;
59+ using scalar_t = typename Params::scalar_t ;
60+ using binop_t = typename Params::binop_t ;
61+ using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<Params , ItemsPerInvocation, native>;
5562
5663 // NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
5764
@@ -64,19 +71,19 @@ struct exclusive_scan
6471
6572 type_t retval;
6673 retval[0 ] = bool (glsl::gl_SubgroupInvocationID ()) ? left[ItemsPerInvocation-1 ] : binop_t::identity;
67- // [unroll(ItemsPerInvocation-1) ]
74+ [unroll]
6875 for (uint32_t i = 1 ; i < ItemsPerInvocation; i++)
6976 retval[i] = value[i-1 ];
7077 return retval;
7178 }
7279};
7380
74- template<class Binop, typename T , uint32_t ItemsPerInvocation, bool native>
81+ template<class Params , uint32_t ItemsPerInvocation, bool native>
7582struct reduction
7683{
77- using type_t = T; // TODO? assert scalar_type<T> == scalar_t
78- using scalar_t = typename Binop::type_t ;
79- using binop_t = Binop ;
84+ using type_t = typename Params::type_t;
85+ using scalar_t = typename Params::scalar_t ;
86+ using binop_t = typename Params::binop_t ;
8087 using op_t = subgroup::impl::reduction<binop_t, native>;
8188
8289 // NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
@@ -86,49 +93,81 @@ struct reduction
8693 binop_t binop;
8794 op_t op;
8895 scalar_t retval = value[0 ];
89- // [unroll(ItemsPerInvocation-1) ]
96+ [unroll]
9097 for (uint32_t i = 1 ; i < ItemsPerInvocation; i++)
9198 retval = binop (retval, value[i]);
9299 return op (retval);
93100 }
94101};
95102
96103
97- // spec for N=1 uses subgroup funcs
98- template<class Binop, typename T, bool native>
99- struct inclusive_scan<Binop, T, 1 , native>
104+ // specs for N=1 uses subgroup funcs
105+ // specialize native
106+ // #define SPECIALIZE(NAME,BINOP,SUBGROUP_OP) template<typename T> struct NAME<BINOP<T>,true> \
107+ // { \
108+ // using type_t = T; \
109+ // \
110+ // type_t operator()(NBL_CONST_REF_ARG(type_t) v) {return glsl::subgroup##SUBGROUP_OP<type_t>(v);} \
111+ // }
112+
113+ // #define SPECIALIZE_ALL(BINOP,SUBGROUP_OP) SPECIALIZE(reduction,BINOP,SUBGROUP_OP); \
114+ // SPECIALIZE(inclusive_scan,BINOP,Inclusive##SUBGROUP_OP); \
115+ // SPECIALIZE(exclusive_scan,BINOP,Exclusive##SUBGROUP_OP);
116+
117+ // SPECIALIZE_ALL(bit_and,And);
118+ // SPECIALIZE_ALL(bit_or,Or);
119+ // SPECIALIZE_ALL(bit_xor,Xor);
120+
121+ // SPECIALIZE_ALL(plus,Add);
122+ // SPECIALIZE_ALL(multiplies,Mul);
123+
124+ // SPECIALIZE_ALL(minimum,Min);
125+ // SPECIALIZE_ALL(maximum,Max);
126+
127+ // #undef SPECIALIZE_ALL
128+ // #undef SPECIALIZE
129+
130+ // specialize portability
131+ template<class Params, bool native>
132+ struct inclusive_scan<Params, 1 , native>
100133{
101- using binop_t = Binop;
134+ using type_t = typename Params::type_t;
135+ using scalar_t = typename Params::scalar_t;
136+ using binop_t = typename Params::binop_t;
102137 using op_t = subgroup::impl::inclusive_scan<binop_t, native>;
103138 // assert T == scalar type, binop::type == T
104139
105- T operator ()(NBL_CONST_REF_ARG (T ) value)
140+ type_t operator ()(NBL_CONST_REF_ARG (type_t ) value)
106141 {
107142 op_t op;
108143 return op (value);
109144 }
110145};
111146
112- template<class Binop, typename T , bool native>
113- struct exclusive_scan<Binop, T , 1 , native>
147+ template<class Params , bool native>
148+ struct exclusive_scan<Params , 1 , native>
114149{
115- using binop_t = Binop;
150+ using type_t = typename Params::type_t;
151+ using scalar_t = typename Params::scalar_t;
152+ using binop_t = typename Params::binop_t;
116153 using op_t = subgroup::impl::exclusive_scan<binop_t, native>;
117154
118- T operator ()(NBL_CONST_REF_ARG (T ) value)
155+ type_t operator ()(NBL_CONST_REF_ARG (type_t ) value)
119156 {
120157 op_t op;
121158 return op (value);
122159 }
123160};
124161
125- template<class Binop, typename T , bool native>
126- struct reduction<Binop, T , 1 , native>
162+ template<class Params , bool native>
163+ struct reduction<Params , 1 , native>
127164{
128- using binop_t = Binop;
165+ using type_t = typename Params::type_t;
166+ using scalar_t = typename Params::scalar_t;
167+ using binop_t = typename Params::binop_t;
129168 using op_t = subgroup::impl::reduction<binop_t, native>;
130169
131- T operator ()(NBL_CONST_REF_ARG (T ) value)
170+ scalar_t operator ()(NBL_CONST_REF_ARG (type_t ) value)
132171 {
133172 op_t op;
134173 return op (value);
0 commit comments