44#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
55#define _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
66
7- #include "nbl/builtin/hlsl/subgroup/arithmetic_portability .hlsl"
7+ #include "nbl/builtin/hlsl/subgroup/arithmetic_portability_impl .hlsl"
88
99namespace nbl
1010{
@@ -16,15 +16,15 @@ namespace subgroup2
1616namespace impl
1717{
1818
19- template<class Binop, typename T, bool native>
19+ template<class Binop, typename T, uint32_t ItemsPerInvocation, bool native>
2020struct inclusive_scan
2121{
2222 using type_t = T;
2323 using scalar_t = typename Binop::type_t;
2424 using binop_t = Binop;
2525 using exclusive_scan_op_t = subgroup::impl::exclusive_scan<binop_t, native>;
2626
27- NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
27+ // NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
2828
2929 type_t operator ()(NBL_CONST_REF_ARG (type_t) value)
3030 {
@@ -45,15 +45,15 @@ struct inclusive_scan
4545 }
4646};
4747
48- template<class Binop, typename T, bool native>
48+ template<class Binop, typename T, uint32_t ItemsPerInvocation, bool native>
4949struct exclusive_scan
5050{
5151 using type_t = T;
5252 using scalar_t = typename Binop::type_t;
5353 using binop_t = Binop;
54- using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<binop_t, T, native>;
54+ using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<binop_t, T, ItemsPerInvocation, native>;
5555
56- NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
56+ // NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
5757
5858 type_t operator ()(type_t value)
5959 {
@@ -71,15 +71,15 @@ struct exclusive_scan
7171 }
7272};
7373
74- template<class Binop, typename T, bool native>
74+ template<class Binop, typename T, uint32_t ItemsPerInvocation, bool native>
7575struct reduction
7676{
7777 using type_t = T; // TODO? assert scalar_type<T> == scalar_t
7878 using scalar_t = typename Binop::type_t;
7979 using binop_t = Binop;
8080 using op_t = subgroup::impl::reduction<binop_t, native>;
8181
82- NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
82+ // NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
8383
8484 scalar_t operator ()(NBL_CONST_REF_ARG (type_t) value)
8585 {
@@ -93,6 +93,48 @@ struct reduction
9393 }
9494};
9595
96+
// Specialization for ItemsPerInvocation == 1: uses the subgroup functions by
// forwarding straight to subgroup::impl::inclusive_scan.
// Precondition (not enforced here): T is a scalar type and Binop::type_t == T.
template<class Binop, typename T, bool native>
struct inclusive_scan<Binop, T, 1, native>
{
    // Expose the same member aliases as the primary template so that generic
    // code can query type_t / scalar_t regardless of ItemsPerInvocation.
    using type_t = T;
    using scalar_t = typename Binop::type_t;
    using binop_t = Binop;
    using op_t = subgroup::impl::inclusive_scan<binop_t, native>;

    // Returns whatever subgroup::impl::inclusive_scan<binop_t, native> yields for `value`.
    T operator()(NBL_CONST_REF_ARG(T) value)
    {
        op_t op;
        return op(value);
    }
};
111+
// Specialization for ItemsPerInvocation == 1: forwards straight to
// subgroup::impl::exclusive_scan.
// NOTE(review): presumably requires T to be a scalar type equal to
// Binop::type_t; nothing in this struct enforces it.
template<class Binop, typename T, bool native>
struct exclusive_scan<Binop, T, 1, native>
{
    // Expose the same member aliases as the primary template so that generic
    // code can query type_t / scalar_t regardless of ItemsPerInvocation.
    using type_t = T;
    using scalar_t = typename Binop::type_t;
    using binop_t = Binop;
    using op_t = subgroup::impl::exclusive_scan<binop_t, native>;

    // Returns whatever subgroup::impl::exclusive_scan<binop_t, native> yields for `value`.
    T operator()(NBL_CONST_REF_ARG(T) value)
    {
        op_t op;
        return op(value);
    }
};
124+
// Specialization for ItemsPerInvocation == 1: forwards straight to
// subgroup::impl::reduction.
// NOTE(review): presumably requires T to be a scalar type equal to
// Binop::type_t; nothing in this struct enforces it. The return type here is T,
// whereas the primary template returns scalar_t.
template<class Binop, typename T, bool native>
struct reduction<Binop, T, 1, native>
{
    // Expose the same member aliases as the primary template so that generic
    // code can query type_t / scalar_t regardless of ItemsPerInvocation.
    using type_t = T;
    using scalar_t = typename Binop::type_t;
    using binop_t = Binop;
    using op_t = subgroup::impl::reduction<binop_t, native>;

    // Returns whatever subgroup::impl::reduction<binop_t, native> yields for `value`.
    T operator()(NBL_CONST_REF_ARG(T) value)
    {
        op_t op;
        return op(value);
    }
};
137+
96138}
97139
98140}
0 commit comments