From a0ca536845ffd540a3bfd62da6e5d96aeaf97c39 Mon Sep 17 00:00:00 2001 From: firestar99 Date: Wed, 22 Oct 2025 18:40:16 +0200 Subject: [PATCH 1/6] ScalarComposite: add `ScalarOrVectorComposite` for subgroup intrinsics --- crates/spirv-std/src/arch/subgroup.rs | 350 +++++++++++------- crates/spirv-std/src/scalar.rs | 10 +- crates/spirv-std/src/scalar_or_vector.rs | 60 ++- crates/spirv-std/src/vector.rs | 8 +- .../subgroup_cluster_size_0_fail.stderr | 4 +- ..._cluster_size_non_power_of_two_fail.stderr | 4 +- 6 files changed, 288 insertions(+), 148 deletions(-) diff --git a/crates/spirv-std/src/arch/subgroup.rs b/crates/spirv-std/src/arch/subgroup.rs index a9690d4190..50b2367899 100644 --- a/crates/spirv-std/src/arch/subgroup.rs +++ b/crates/spirv-std/src/arch/subgroup.rs @@ -1,11 +1,12 @@ -use crate::ScalarOrVector; #[cfg(target_arch = "spirv")] -use crate::arch::barrier; +use crate::ScalarOrVectorTransform; #[cfg(target_arch = "spirv")] -use crate::memory::{Scope, Semantics}; -use crate::{Float, Integer, SignedInteger, UnsignedInteger}; +use crate::arch::{asm, barrier}; #[cfg(target_arch = "spirv")] -use core::arch::asm; +use crate::memory::{Scope, Semantics}; +use crate::{ + Float, Integer, ScalarOrVector, ScalarOrVectorComposite, SignedInteger, UnsignedInteger, +}; #[cfg(target_arch = "spirv")] const SUBGROUP: u32 = Scope::Subgroup as u32; @@ -287,25 +288,34 @@ pub fn subgroup_all_equal(value: T) -> bool { #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformBroadcast")] #[inline] -pub unsafe fn subgroup_broadcast(value: T, id: u32) -> T { - let mut result = T::default(); +pub unsafe fn subgroup_broadcast(value: T, id: u32) -> T { + struct Transform { + id: u32, + } - unsafe { - asm! { - "%u32 = OpTypeInt 32 0", - "%subgroup = OpConstant %u32 {subgroup}", - "%value = OpLoad _ {value}", - "%id = OpLoad _ {id}", - "%result = OpGroupNonUniformBroadcast _ %subgroup %value %id", - "OpStore {result} %result", - subgroup = const SUBGROUP, - value = in(reg) &value, - id = in(reg) &id, - result = in(reg) &mut result, + impl ScalarOrVectorTransform for Transform { + #[inline] + fn transform(&mut self, value: T) -> T { + let mut result = T::default(); + unsafe { + asm! { + "%u32 = OpTypeInt 32 0", + "%subgroup = OpConstant %u32 {subgroup}", + "%value = OpLoad _ {value}", + "%id = OpLoad _ {id}", + "%result = OpGroupNonUniformBroadcast _ %subgroup %value %id", + "OpStore {result} %result", + subgroup = const SUBGROUP, + value = in(reg) &value, + id = in(reg) &self.id, + result = in(reg) &mut result, + } + } + result } } - result + value.transform(&mut Transform { id }) } /// Result is the `value` of the invocation identified by the id `id` to all active invocations in the group. @@ -330,24 +340,31 @@ pub unsafe fn subgroup_broadcast(value: T, id: u32) -> T { #[doc(alias = "OpGroupNonUniformBroadcast")] #[inline] pub unsafe fn subgroup_broadcast_const(value: T) -> T { - let mut result = T::default(); + struct Transform; - unsafe { - asm! { - "%u32 = OpTypeInt 32 0", - "%subgroup = OpConstant %u32 {subgroup}", - "%id = OpConstant %u32 {id}", - "%value = OpLoad _ {value}", - "%result = OpGroupNonUniformBroadcast _ %subgroup %value %id", - "OpStore {result} %result", - subgroup = const SUBGROUP, - value = in(reg) &value, - id = const ID, - result = in(reg) &mut result, + impl ScalarOrVectorTransform for Transform { + #[inline] + fn transform(&mut self, value: T) -> T { + let mut result = T::default(); + unsafe { + asm! { + "%u32 = OpTypeInt 32 0", + "%subgroup = OpConstant %u32 {subgroup}", + "%id = OpConstant %u32 {id}", + "%value = OpLoad _ {value}", + "%result = OpGroupNonUniformBroadcast _ %subgroup %value %id", + "OpStore {result} %result", + subgroup = const SUBGROUP, + value = in(reg) &value, + id = const ID, + result = in(reg) &mut result, + } + } + result } } - result + value.transform(&mut Transform::) } /// Result is the `value` of the invocation from the active invocation with the lowest id in the group to all active invocations in the group. @@ -362,23 +379,30 @@ pub unsafe fn subgroup_broadcast_const(value: #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformBroadcastFirst")] #[inline] -pub fn subgroup_broadcast_first(value: T) -> T { - let mut result = T::default(); +pub fn subgroup_broadcast_first(value: T) -> T { + struct Transform; - unsafe { - asm! { - "%u32 = OpTypeInt 32 0", - "%subgroup = OpConstant %u32 {subgroup}", - "%value = OpLoad _ {value}", - "%result = OpGroupNonUniformBroadcastFirst _ %subgroup %value", - "OpStore {result} %result", - subgroup = const SUBGROUP, - value = in(reg) &value, - result = in(reg) &mut result, + impl ScalarOrVectorTransform for Transform { + #[inline] + fn transform(&mut self, value: T) -> T { + let mut result = T::default(); + unsafe { + asm! { + "%u32 = OpTypeInt 32 0", + "%subgroup = OpConstant %u32 {subgroup}", + "%value = OpLoad _ {value}", + "%result = OpGroupNonUniformBroadcastFirst _ %subgroup %value", + "OpStore {result} %result", + subgroup = const SUBGROUP, + value = in(reg) &value, + result = in(reg) &mut result, + } + } + result } } - result + value.transform(&mut Transform) } /// Result is a bitfield value combining the `predicate` value from all invocations in the group that execute the same dynamic instance of this instruction. The bit is set to one if the corresponding invocation is active and the `predicate` for that invocation evaluated to true; otherwise, it is set to zero. @@ -637,25 +661,34 @@ pub fn subgroup_ballot_find_msb(value: SubgroupMask) -> u32 { #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformShuffle")] #[inline] -pub fn subgroup_shuffle(value: T, id: u32) -> T { - let mut result = T::default(); +pub fn subgroup_shuffle(value: T, id: u32) -> T { + struct Transform { + id: u32, + } - unsafe { - asm! { - "%u32 = OpTypeInt 32 0", - "%subgroup = OpConstant %u32 {subgroup}", - "%value = OpLoad _ {value}", - "%id = OpLoad _ {id}", - "%result = OpGroupNonUniformShuffle _ %subgroup %value %id", - "OpStore {result} %result", - subgroup = const SUBGROUP, - value = in(reg) &value, - id = in(reg) &id, - result = in(reg) &mut result, + impl ScalarOrVectorTransform for Transform { + #[inline] + fn transform(&mut self, value: T) -> T { + let mut result = T::default(); + unsafe { + asm! { + "%u32 = OpTypeInt 32 0", + "%subgroup = OpConstant %u32 {subgroup}", + "%value = OpLoad _ {value}", + "%id = OpLoad _ {id}", + "%result = OpGroupNonUniformShuffle _ %subgroup %value %id", + "OpStore {result} %result", + subgroup = const SUBGROUP, + value = in(reg) &value, + id = in(reg) &self.id, + result = in(reg) &mut result, + } + } + result } } - result + value.transform(&mut Transform { id }) } /// Result is the `value` of the invocation identified by the current invocation’s id within the group xor’ed with Mask. @@ -678,25 +711,34 @@ pub fn subgroup_shuffle(value: T, id: u32) -> T { #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformShuffleXor")] #[inline] -pub fn subgroup_shuffle_xor(value: T, mask: u32) -> T { - let mut result = T::default(); +pub fn subgroup_shuffle_xor(value: T, mask: u32) -> T { + struct Transform { + mask: u32, + } - unsafe { - asm! { - "%u32 = OpTypeInt 32 0", - "%subgroup = OpConstant %u32 {subgroup}", - "%value = OpLoad _ {value}", - "%mask = OpLoad _ {mask}", - "%result = OpGroupNonUniformShuffleXor _ %subgroup %value %mask", - "OpStore {result} %result", - subgroup = const SUBGROUP, - value = in(reg) &value, - mask = in(reg) &mask, - result = in(reg) &mut result, + impl ScalarOrVectorTransform for Transform { + #[inline] + fn transform(&mut self, value: T) -> T { + let mut result = T::default(); + unsafe { + asm! { + "%u32 = OpTypeInt 32 0", + "%subgroup = OpConstant %u32 {subgroup}", + "%value = OpLoad _ {value}", + "%mask = OpLoad _ {mask}", + "%result = OpGroupNonUniformShuffleXor _ %subgroup %value %mask", + "OpStore {result} %result", + subgroup = const SUBGROUP, + value = in(reg) &value, + mask = in(reg) &self.mask, + result = in(reg) &mut result, + } + } + result } } - result + value.transform(&mut Transform { mask }) } /// Result is the `value` of the invocation identified by the current invocation’s id within the group - Delta. @@ -719,25 +761,34 @@ pub fn subgroup_shuffle_xor(value: T, mask: u32) -> T { #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformShuffleUp")] #[inline] -pub fn subgroup_shuffle_up(value: T, delta: u32) -> T { - let mut result = T::default(); +pub fn subgroup_shuffle_up(value: T, delta: u32) -> T { + struct Transform { + delta: u32, + } - unsafe { - asm! { - "%u32 = OpTypeInt 32 0", - "%subgroup = OpConstant %u32 {subgroup}", - "%value = OpLoad _ {value}", - "%delta = OpLoad _ {delta}", - "%result = OpGroupNonUniformShuffleUp _ %subgroup %value %delta", - "OpStore {result} %result", - subgroup = const SUBGROUP, - value = in(reg) &value, - delta = in(reg) &delta, - result = in(reg) &mut result, + impl ScalarOrVectorTransform for Transform { + #[inline] + fn transform(&mut self, value: T) -> T { + let mut result = T::default(); + unsafe { + asm! { + "%u32 = OpTypeInt 32 0", + "%subgroup = OpConstant %u32 {subgroup}", + "%value = OpLoad _ {value}", + "%delta = OpLoad _ {delta}", + "%result = OpGroupNonUniformShuffleUp _ %subgroup %value %delta", + "OpStore {result} %result", + subgroup = const SUBGROUP, + value = in(reg) &value, + delta = in(reg) &self.delta, + result = in(reg) &mut result, + } + } + result } } - result + value.transform(&mut Transform { delta }) } /// Result is the `value` of the invocation identified by the current invocation’s id within the group + Delta. @@ -760,25 +811,34 @@ pub fn subgroup_shuffle_up(value: T, delta: u32) -> T { #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformShuffleDown")] #[inline] -pub fn subgroup_shuffle_down(value: T, delta: u32) -> T { - let mut result = T::default(); +pub fn subgroup_shuffle_down(value: T, delta: u32) -> T { + struct Transform { + delta: u32, + } - unsafe { - asm! { - "%u32 = OpTypeInt 32 0", - "%subgroup = OpConstant %u32 {subgroup}", - "%value = OpLoad _ {value}", - "%delta = OpLoad _ {delta}", - "%result = OpGroupNonUniformShuffleDown _ %subgroup %value %delta", - "OpStore {result} %result", - subgroup = const SUBGROUP, - value = in(reg) &value, - delta = in(reg) &delta, - result = in(reg) &mut result, + impl ScalarOrVectorTransform for Transform { + #[inline] + fn transform(&mut self, value: T) -> T { + let mut result = T::default(); + unsafe { + asm! { + "%u32 = OpTypeInt 32 0", + "%subgroup = OpConstant %u32 {subgroup}", + "%value = OpLoad _ {value}", + "%delta = OpLoad _ {delta}", + "%result = OpGroupNonUniformShuffleDown _ %subgroup %value %delta", + "OpStore {result} %result", + subgroup = const SUBGROUP, + value = in(reg) &value, + delta = in(reg) &self.delta, + result = in(reg) &mut result, + } + } + result } } - result + value.transform(&mut Transform { delta }) } macro_rules! macro_subgroup_op { @@ -1387,25 +1447,34 @@ Requires Capability `GroupNonUniformArithmetic` and `GroupNonUniformClustered`. #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformQuadBroadcast")] #[inline] -pub fn subgroup_quad_broadcast(value: T, index: u32) -> T { - let mut result = T::default(); +pub fn subgroup_quad_broadcast(value: T, index: u32) -> T { + struct Transform { + index: u32, + } - unsafe { - asm! { - "%u32 = OpTypeInt 32 0", - "%subgroup = OpConstant %u32 {subgroup}", - "%value = OpLoad _ {value}", - "%index = OpLoad _ {index}", - "%result = OpGroupNonUniformQuadBroadcast _ %subgroup %value %index", - "OpStore {result} %result", - subgroup = const SUBGROUP, - value = in(reg) &value, - index = in(reg) &index, - result = in(reg) &mut result, + impl ScalarOrVectorTransform for Transform { + #[inline] + fn transform(&mut self, value: T) -> T { + let mut result = T::default(); + unsafe { + asm! { + "%u32 = OpTypeInt 32 0", + "%subgroup = OpConstant %u32 {subgroup}", + "%value = OpLoad _ {value}", + "%index = OpLoad _ {index}", + "%result = OpGroupNonUniformQuadBroadcast _ %subgroup %value %index", + "OpStore {result} %result", + subgroup = const SUBGROUP, + value = in(reg) &value, + index = in(reg) &self.index, + result = in(reg) &mut result, + } + } + result } } - result + value.transform(&mut Transform { index }) } /// Direction is the kind of swap to perform. @@ -1470,23 +1539,30 @@ pub enum QuadDirection { #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformQuadSwap")] #[inline] -pub fn subgroup_quad_swap(value: T) -> T { - let mut result = T::default(); +pub fn subgroup_quad_swap(value: T) -> T { + struct Transform; - unsafe { - asm! { - "%u32 = OpTypeInt 32 0", - "%subgroup = OpConstant %u32 {subgroup}", - "%direction = OpConstant %u32 {direction}", - "%value = OpLoad _ {value}", - "%result = OpGroupNonUniformQuadSwap _ %subgroup %value %direction", - "OpStore {result} %result", - subgroup = const SUBGROUP, - direction = const DIRECTION, - value = in(reg) &value, - result = in(reg) &mut result, + impl ScalarOrVectorTransform for Transform { + #[inline] + fn transform(&mut self, value: T) -> T { + let mut result = T::default(); + unsafe { + asm! { + "%u32 = OpTypeInt 32 0", + "%subgroup = OpConstant %u32 {subgroup}", + "%direction = OpConstant %u32 {direction}", + "%value = OpLoad _ {value}", + "%result = OpGroupNonUniformQuadSwap _ %subgroup %value %direction", + "OpStore {result} %result", + subgroup = const SUBGROUP, + direction = const DIRECTION, + value = in(reg) &value, + result = in(reg) &mut result, + } + } + result } } - result + value.transform(&mut Transform::) } diff --git a/crates/spirv-std/src/scalar.rs b/crates/spirv-std/src/scalar.rs index b4774ea861..363035dd49 100644 --- a/crates/spirv-std/src/scalar.rs +++ b/crates/spirv-std/src/scalar.rs @@ -1,7 +1,7 @@ //! Traits related to scalars. -use crate::ScalarOrVector; use crate::sealed::Sealed; +use crate::{ScalarOrVector, ScalarOrVectorComposite, ScalarOrVectorTransform}; use core::num::NonZeroUsize; /// Abstract trait representing a SPIR-V scalar type, which includes: @@ -61,7 +61,13 @@ pub unsafe trait Float: num_traits::Float + Number { macro_rules! impl_scalar { (impl Scalar for $ty:ty;) => { impl Sealed for $ty {} - unsafe impl ScalarOrVector for $ty { + impl ScalarOrVectorComposite for $ty { + #[inline] + fn transform(self, f: &mut F) -> Self { + f.transform_scalar(self) + } + } + unsafe impl ScalarOrVector for $ty { type Scalar = Self; const N: NonZeroUsize = NonZeroUsize::new(1).unwrap(); } diff --git a/crates/spirv-std/src/scalar_or_vector.rs b/crates/spirv-std/src/scalar_or_vector.rs index 87b0073241..a8497924e8 100644 --- a/crates/spirv-std/src/scalar_or_vector.rs +++ b/crates/spirv-std/src/scalar_or_vector.rs @@ -1,4 +1,4 @@ -use crate::Scalar; +use crate::{Scalar, Vector}; use core::num::NonZeroUsize; pub(crate) mod sealed { @@ -11,12 +11,64 @@ pub(crate) mod sealed { /// /// # Safety /// Your type must also implement [`Scalar`] or [`Vector`], see their safety sections as well. -/// -/// [`Vector`]: crate::Vector -pub unsafe trait ScalarOrVector: Copy + Default + Send + Sync + 'static { +pub unsafe trait ScalarOrVector: ScalarOrVectorComposite + Default { /// Either the scalar component type of the vector or the scalar itself. type Scalar: Scalar; /// The dimension of the vector, or 1 if it is a scalar const N: NonZeroUsize; } + +/// A `VectorOrScalarComposite` is a type that is either +/// * a [`Scalar`] +/// * a [`Vector`] +/// * an array of `VectorOrScalarComposite` +/// * a struct where all members are `VectorOrScalarComposite` +/// * an enum with a `repr` that is a [`Scalar`] +/// +/// By calling [`Self::transform`] you can visit all the individual [`Scalar`] and [`Vector`] values this composite is +/// build out of and transform them into some other value. This is particularly useful for subgroup intrinsics sending +/// data to other threads. +/// +/// To derive `#[derive(VectorOrScalarComposite)]` on a struct, all members must also implement +/// `VectorOrScalarComposite`. To derive it on an enum, the enum must have `#[repr(N)]` where `N` is an [`Integer`]. +/// Additionally, you must derive `num_enum::FromPrimitive` and `num_enum::ToPrimitive`, which requires the enum to be +/// either exhaustive, implement [`Default`] or a variant of the enum to have the `#[num_enum(default)]` attribute. +/// +/// [`Integer`]: crate::Integer +pub trait ScalarOrVectorComposite: Copy + Send + Sync + 'static { + /// Transform the individual [`Scalar`] and [`Vector`] values of this type to a different value. + /// + /// See [`Self`] for more detail. + fn transform(self, f: &mut F) -> Self; +} + +/// A transform operation for [`ScalarOrVectorComposite::transform`] +pub trait ScalarOrVectorTransform { + /// transform a [`ScalarOrVector`] + fn transform(&mut self, value: T) -> T; + + /// transform a [`Scalar`], defaults to [`self.transform`] + #[inline] + fn transform_scalar(&mut self, value: T) -> T { + self.transform(value) + } + + /// transform a [`Vector`], defaults to [`self.transform`] + #[inline] + fn transform_vector, S: Scalar, const N: usize>(&mut self, value: V) -> V { + self.transform(value) + } +} + +/// `Default` is unfortunately necessary until rust-gpu improves +impl ScalarOrVectorComposite for [T; N] { + #[inline] + fn transform(self, f: &mut F) -> Self { + let mut out = [T::default(); N]; + for i in 0..N { + out[i] = self[i].transform(f); + } + out + } +} diff --git a/crates/spirv-std/src/vector.rs b/crates/spirv-std/src/vector.rs index 0389424df0..c5464adfc2 100644 --- a/crates/spirv-std/src/vector.rs +++ b/crates/spirv-std/src/vector.rs @@ -1,7 +1,7 @@ //! Traits related to vectors. use crate::sealed::Sealed; -use crate::{Scalar, ScalarOrVector}; +use crate::{Scalar, ScalarOrVector, ScalarOrVectorComposite, ScalarOrVectorTransform}; use core::num::NonZeroUsize; use glam::{Vec3Swizzles, Vec4Swizzles}; @@ -57,6 +57,12 @@ macro_rules! impl_vector { ($($ty:ty: [$scalar:ty; $n:literal];)+) => { $( impl Sealed for $ty {} + impl ScalarOrVectorComposite for $ty { + #[inline] + fn transform(self, f: &mut F) -> Self { + f.transform_vector(self) + } + } unsafe impl ScalarOrVector for $ty { type Scalar = $scalar; const N: NonZeroUsize = NonZeroUsize::new($n).unwrap(); diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_0_fail.stderr b/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_0_fail.stderr index c292b79934..9761234110 100644 --- a/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_0_fail.stderr +++ b/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_0_fail.stderr @@ -1,5 +1,5 @@ error[E0080]: evaluation panicked: `ClusterSize` must be at least 1 - --> $SPIRV_STD_SRC/arch/subgroup.rs:868:1 + --> $SPIRV_STD_SRC/arch/subgroup.rs:927:1 | LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r" LL | | An integer add group operation of all `value` operands contributed by active invocations in the group. @@ -13,7 +13,7 @@ LL | | "); = note: this error originates in the macro `$crate::panic::panic_2021` which comes from the expansion of the macro `macro_subgroup_op_clustered` (in Nightly builds, run with -Z macro-backtrace for more info) note: erroneous constant encountered - --> $SPIRV_STD_SRC/arch/subgroup.rs:868:1 + --> $SPIRV_STD_SRC/arch/subgroup.rs:927:1 | LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r" LL | | An integer add group operation of all `value` operands contributed by active invocations in the group. diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_non_power_of_two_fail.stderr b/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_non_power_of_two_fail.stderr index 61d066c3fc..52dbb005d7 100644 --- a/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_non_power_of_two_fail.stderr +++ b/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_non_power_of_two_fail.stderr @@ -1,5 +1,5 @@ error[E0080]: evaluation panicked: `ClusterSize` must be a power of 2 - --> $SPIRV_STD_SRC/arch/subgroup.rs:868:1 + --> $SPIRV_STD_SRC/arch/subgroup.rs:927:1 | LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r" LL | | An integer add group operation of all `value` operands contributed by active invocations in the group. @@ -13,7 +13,7 @@ LL | | "); = note: this error originates in the macro `$crate::panic::panic_2021` which comes from the expansion of the macro `macro_subgroup_op_clustered` (in Nightly builds, run with -Z macro-backtrace for more info) note: erroneous constant encountered - --> $SPIRV_STD_SRC/arch/subgroup.rs:868:1 + --> $SPIRV_STD_SRC/arch/subgroup.rs:927:1 | LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r" LL | | An integer add group operation of all `value` operands contributed by active invocations in the group. From 6c8a6718adf3217b06c1a85bce50d090d3c57da2 Mon Sep 17 00:00:00 2001 From: firestar99 Date: Tue, 14 Oct 2025 18:48:42 +0200 Subject: [PATCH 2/6] ScalarComposite: add `#[derive(ScalarOrVectorComposite)]` for structs --- crates/spirv-std/macros/src/lib.rs | 8 +++ .../macros/src/scalar_or_vector_composite.rs | 55 +++++++++++++++++++ crates/spirv-std/src/lib.rs | 1 + .../ui/arch/subgroup/subgroup_composite.rs | 54 ++++++++++++++++++ .../subgroup/subgroup_composite_shuffle.rs | 45 +++++++++++++++ .../subgroup_composite_shuffle.stderr | 16 ++++++ 6 files changed, 179 insertions(+) create mode 100644 crates/spirv-std/macros/src/scalar_or_vector_composite.rs create mode 100644 tests/compiletests/ui/arch/subgroup/subgroup_composite.rs create mode 100644 tests/compiletests/ui/arch/subgroup/subgroup_composite_shuffle.rs create mode 100644 tests/compiletests/ui/arch/subgroup/subgroup_composite_shuffle.stderr diff --git a/crates/spirv-std/macros/src/lib.rs b/crates/spirv-std/macros/src/lib.rs index d8ecf7b0cf..424c9ede5d 100644 --- a/crates/spirv-std/macros/src/lib.rs +++ b/crates/spirv-std/macros/src/lib.rs @@ -74,6 +74,7 @@ mod debug_printf; mod image; mod sample_param_permutations; +mod scalar_or_vector_composite; use crate::debug_printf::{DebugPrintfInput, debug_printf_inner}; use proc_macro::TokenStream; @@ -311,3 +312,10 @@ pub fn debug_printfln(input: TokenStream) -> TokenStream { pub fn gen_sample_param_permutations(_attr: TokenStream, item: TokenStream) -> TokenStream { sample_param_permutations::gen_sample_param_permutations(item) } + +#[proc_macro_derive(ScalarOrVectorComposite)] +pub fn derive_scalar_or_vector_composite(item: TokenStream) -> TokenStream { + scalar_or_vector_composite::derive(item.into()) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} diff --git a/crates/spirv-std/macros/src/scalar_or_vector_composite.rs b/crates/spirv-std/macros/src/scalar_or_vector_composite.rs new file mode 100644 index 0000000000..8b83ccad7c --- /dev/null +++ b/crates/spirv-std/macros/src/scalar_or_vector_composite.rs @@ -0,0 +1,55 @@ +use proc_macro2::TokenStream; +use quote::{ToTokens, quote}; +use syn::punctuated::Punctuated; +use syn::{Fields, FieldsNamed, FieldsUnnamed, GenericParam, Token}; + +pub fn derive(item: TokenStream) -> syn::Result { + // Whenever we'll properly resolve the crate symbol, replace this. + let spirv_std = quote!(spirv_std); + + // Defer all validation to our codegen backend. Rather than erroring here, emit garbage. + let item = syn::parse2::(item)?; + let struct_ident = &item.ident; + let gens = &item.generics.params; + let gen_refs = &item + .generics + .params + .iter() + .map(|p| match p { + GenericParam::Lifetime(p) => p.lifetime.to_token_stream(), + GenericParam::Type(p) => p.ident.to_token_stream(), + GenericParam::Const(p) => p.ident.to_token_stream(), + }) + .collect::>(); + let where_clause = &item.generics.where_clause; + + let content = + match item.fields { + Fields::Named(FieldsNamed { named, .. }) => { + let content = named.iter().map(|f| { + let ident = &f.ident; + quote!(#ident: #spirv_std::ScalarOrVectorComposite::transform(self.#ident, f)) + }).collect::>(); + quote!(Self { #content }) + } + Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => { + let content = (0..unnamed.len()) + .map(|i| { + let i = syn::Index::from(i); + quote!(#spirv_std::ScalarOrVectorComposite::transform(self.#i, f)) + }) + .collect::>(); + quote!(Self(#content)) + } + Fields::Unit => quote!(Self), + }; + + Ok(quote! { + impl<#gens> #spirv_std::ScalarOrVectorComposite for #struct_ident<#gen_refs> #where_clause { + #[inline] + fn transform(self, f: &mut F) -> Self { + #content + } + } + }) +} diff --git a/crates/spirv-std/src/lib.rs b/crates/spirv-std/src/lib.rs index 2c85dc9af0..2fdcd5610d 100644 --- a/crates/spirv-std/src/lib.rs +++ b/crates/spirv-std/src/lib.rs @@ -87,6 +87,7 @@ /// Public re-export of the `spirv-std-macros` crate. #[macro_use] pub extern crate spirv_std_macros as macros; +pub use macros::ScalarOrVectorComposite; pub use macros::spirv; pub use macros::{debug_printf, debug_printfln}; diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite.rs b/tests/compiletests/ui/arch/subgroup/subgroup_composite.rs new file mode 100644 index 0000000000..ca4829f521 --- /dev/null +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite.rs @@ -0,0 +1,54 @@ +// build-pass +// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformBallot,+GroupNonUniformShuffle,+GroupNonUniformShuffleRelative,+ext:SPV_KHR_vulkan_memory_model +// normalize-stderr-test "OpLine .*\n" -> "" +// ignore-vulkan1.0 +// ignore-vulkan1.1 +// ignore-spv1.0 +// ignore-spv1.1 +// ignore-spv1.2 +// ignore-spv1.3 +// ignore-spv1.4 + +use glam::*; +use spirv_std::ScalarOrVectorComposite; +use spirv_std::arch::*; +use spirv_std::spirv; + +#[derive(Copy, Clone, ScalarOrVectorComposite)] +pub struct MyStruct { + a: f32, + b: UVec3, + c: Nested, + d: Zst, +} + +#[derive(Copy, Clone, ScalarOrVectorComposite)] +pub struct Nested(i32); + +#[derive(Copy, Clone, ScalarOrVectorComposite)] +pub struct Zst; + +#[spirv(compute(threads(32)))] +pub fn main( + #[spirv(local_invocation_index)] inv_id: UVec3, + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] output: &mut UVec3, +) { + unsafe { + let my_struct = MyStruct { + a: 1., + b: inv_id, + c: Nested(-42), + d: Zst, + }; + + let mut out = UVec3::ZERO; + // before spv1.5 / vulkan1.2, this id = 19 must be a constant + out += subgroup_broadcast(my_struct, 19).b; + out += subgroup_broadcast_first(my_struct).b; + out += subgroup_shuffle(my_struct, 2).b; + out += subgroup_shuffle_xor(my_struct, 4).b; + out += subgroup_shuffle_up(my_struct, 5).b; + out += subgroup_shuffle_down(my_struct, 7).b; + *output = out; + } +} diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_shuffle.rs b/tests/compiletests/ui/arch/subgroup/subgroup_composite_shuffle.rs new file mode 100644 index 0000000000..1009fb74ca --- /dev/null +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_shuffle.rs @@ -0,0 +1,45 @@ +// build-pass +// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformShuffle,+ext:SPV_KHR_vulkan_memory_model +// compile-flags: -C llvm-args=--disassemble-fn=subgroup_composite_shuffle::disassembly +// normalize-stderr-test "OpLine .*\n" -> "" + +use glam::*; +use spirv_std::ScalarOrVectorComposite; +use spirv_std::arch::*; +use spirv_std::spirv; + +#[derive(Copy, Clone, ScalarOrVectorComposite)] +pub struct MyStruct { + a: f32, + b: UVec3, + c: Nested, + d: Zst, +} + +#[derive(Copy, Clone, ScalarOrVectorComposite)] +pub struct Nested(i32); + +#[derive(Copy, Clone, ScalarOrVectorComposite)] +pub struct Zst; + +/// this should be 3 `subgroup_shuffle` instructions, with all calls inlined +fn disassembly(my_struct: MyStruct, id: u32) -> MyStruct { + subgroup_shuffle(my_struct, id) +} + +#[spirv(compute(threads(32)))] +pub fn main( + #[spirv(local_invocation_index)] inv_id: UVec3, + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] output: &mut MyStruct, +) { + unsafe { + let my_struct = MyStruct { + a: inv_id.x as f32, + b: inv_id, + c: Nested(5i32 - inv_id.x as i32), + d: Zst, + }; + + *output = disassembly(my_struct, 5); + } +} diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_shuffle.stderr b/tests/compiletests/ui/arch/subgroup/subgroup_composite_shuffle.stderr new file mode 100644 index 0000000000..0127324087 --- /dev/null +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_shuffle.stderr @@ -0,0 +1,16 @@ +%1 = OpFunction %2 None %3 +%4 = OpFunctionParameter %2 +%5 = OpFunctionParameter %6 +%7 = OpLabel +%9 = OpCompositeExtract %10 %4 0 +%12 = OpGroupNonUniformShuffle %10 %13 %9 %5 +%14 = OpCompositeExtract %15 %4 1 +%16 = OpGroupNonUniformShuffle %15 %13 %14 %5 +%17 = OpCompositeExtract %18 %4 2 +%19 = OpGroupNonUniformShuffle %18 %13 %17 %5 +%20 = OpCompositeInsert %2 %12 %21 0 +%22 = OpCompositeInsert %2 %16 %20 1 +%23 = OpCompositeInsert %2 %19 %22 2 +OpNoLine +OpReturnValue %23 +OpFunctionEnd From 2e883b1fd981bcda98c488420644f50d5b81aedf Mon Sep 17 00:00:00 2001 From: firestar99 Date: Wed, 22 Oct 2025 18:41:51 +0200 Subject: [PATCH 3/6] ScalarComposite: adjust `subgroup_all_equal` to accept composites --- crates/spirv-std/src/arch/subgroup.rs | 39 ++++++++++------ .../subgroup_cluster_size_0_fail.stderr | 4 +- ..._cluster_size_non_power_of_two_fail.stderr | 4 +- .../subgroup/subgroup_composite_all_equals.rs | 46 +++++++++++++++++++ .../subgroup_composite_all_equals.stderr | 15 ++++++ 5 files changed, 90 insertions(+), 18 deletions(-) create mode 100644 tests/compiletests/ui/arch/subgroup/subgroup_composite_all_equals.rs create mode 100644 tests/compiletests/ui/arch/subgroup/subgroup_composite_all_equals.stderr diff --git a/crates/spirv-std/src/arch/subgroup.rs b/crates/spirv-std/src/arch/subgroup.rs index 50b2367899..5985a084c3 100644 --- a/crates/spirv-std/src/arch/subgroup.rs +++ b/crates/spirv-std/src/arch/subgroup.rs @@ -244,24 +244,35 @@ pub fn subgroup_any(predicate: bool) -> bool { #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformAllEqual")] #[inline] -pub fn subgroup_all_equal(value: T) -> bool { - let mut result = false; +pub fn subgroup_all_equal(value: T) -> bool { + struct Transform(bool); - unsafe { - asm! { - "%bool = OpTypeBool", - "%u32 = OpTypeInt 32 0", - "%subgroup = OpConstant %u32 {subgroup}", - "%value = OpLoad _ {value}", - "%result = OpGroupNonUniformAllEqual %bool %subgroup %value", - "OpStore {result} %result", - subgroup = const SUBGROUP, - value = in(reg) &value, - result = in(reg) &mut result, + impl ScalarOrVectorTransform for Transform { + #[inline] + fn transform(&mut self, value: T) -> T { + let mut result = false; + unsafe { + asm! { + "%bool = OpTypeBool", + "%u32 = OpTypeInt 32 0", + "%subgroup = OpConstant %u32 {subgroup}", + "%value = OpLoad _ {value}", + "%result = OpGroupNonUniformAllEqual %bool %subgroup %value", + "OpStore {result} %result", + subgroup = const SUBGROUP, + value = in(reg) &value, + result = in(reg) &mut result, + } + } + self.0 &= result; + value } } - result + let mut transform = Transform(true); + // ignore returned value + value.transform(&mut transform); + transform.0 } /// Result is the `value` of the invocation identified by the id `id` to all active invocations in the group. diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_0_fail.stderr b/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_0_fail.stderr index 9761234110..bc6d3a980f 100644 --- a/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_0_fail.stderr +++ b/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_0_fail.stderr @@ -1,5 +1,5 @@ error[E0080]: evaluation panicked: `ClusterSize` must be at least 1 - --> $SPIRV_STD_SRC/arch/subgroup.rs:927:1 + --> $SPIRV_STD_SRC/arch/subgroup.rs:939:1 | LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r" LL | | An integer add group operation of all `value` operands contributed by active invocations in the group. @@ -13,7 +13,7 @@ LL | | "); = note: this error originates in the macro `$crate::panic::panic_2021` which comes from the expansion of the macro `macro_subgroup_op_clustered` (in Nightly builds, run with -Z macro-backtrace for more info) note: erroneous constant encountered - --> $SPIRV_STD_SRC/arch/subgroup.rs:927:1 + --> $SPIRV_STD_SRC/arch/subgroup.rs:939:1 | LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r" LL | | An integer add group operation of all `value` operands contributed by active invocations in the group. diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_non_power_of_two_fail.stderr b/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_non_power_of_two_fail.stderr index 52dbb005d7..e254fb228b 100644 --- a/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_non_power_of_two_fail.stderr +++ b/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_non_power_of_two_fail.stderr @@ -1,5 +1,5 @@ error[E0080]: evaluation panicked: `ClusterSize` must be a power of 2 - --> $SPIRV_STD_SRC/arch/subgroup.rs:927:1 + --> $SPIRV_STD_SRC/arch/subgroup.rs:939:1 | LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r" LL | | An integer add group operation of all `value` operands contributed by active invocations in the group. @@ -13,7 +13,7 @@ LL | | "); = note: this error originates in the macro `$crate::panic::panic_2021` which comes from the expansion of the macro `macro_subgroup_op_clustered` (in Nightly builds, run with -Z macro-backtrace for more info) note: erroneous constant encountered - --> $SPIRV_STD_SRC/arch/subgroup.rs:927:1 + --> $SPIRV_STD_SRC/arch/subgroup.rs:939:1 | LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r" LL | | An integer add group operation of all `value` operands contributed by active invocations in the group. diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_all_equals.rs b/tests/compiletests/ui/arch/subgroup/subgroup_composite_all_equals.rs new file mode 100644 index 0000000000..2c1c12f9aa --- /dev/null +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_all_equals.rs @@ -0,0 +1,46 @@ +// build-pass +// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformVote,+ext:SPV_KHR_vulkan_memory_model +// compile-flags: -C llvm-args=--disassemble-fn=subgroup_composite_all_equals::disassembly +// normalize-stderr-test "OpLine .*\n" -> "" + +use glam::*; +use spirv_std::ScalarOrVectorComposite; +use spirv_std::arch::*; +use spirv_std::spirv; + +#[derive(Copy, Clone, ScalarOrVectorComposite)] +pub struct MyStruct { + a: f32, + b: UVec3, + c: Nested, + d: Zst, +} + +#[derive(Copy, Clone, ScalarOrVectorComposite)] +pub struct Nested(i32); + +#[derive(Copy, Clone, ScalarOrVectorComposite)] +pub struct Zst; + +/// this should be 3 `subgroup_all_equal` instructions, with all calls inlined +fn disassembly(my_struct: MyStruct) -> bool { + subgroup_all_equal(my_struct) +} + +#[spirv(compute(threads(32)))] +pub fn main( + #[spirv(local_invocation_index)] inv_id: UVec3, + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] output: &mut u32, +) { + unsafe { + let my_struct = MyStruct { + a: inv_id.x as f32, + b: inv_id, + c: Nested(5i32 - inv_id.x as i32), + d: Zst, + }; + + let bool = disassembly(my_struct); + *output = u32::from(bool); + } +} diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_all_equals.stderr b/tests/compiletests/ui/arch/subgroup/subgroup_composite_all_equals.stderr new file mode 100644 index 0000000000..d0167e9bed --- /dev/null +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_all_equals.stderr @@ -0,0 +1,15 @@ +%1 = OpFunction %2 None %3 +%4 = OpFunctionParameter %5 +%6 = OpLabel +%8 = OpCompositeExtract %9 %4 0 +%11 = OpGroupNonUniformAllEqual %2 %12 %8 +%13 = OpLogicalAnd %2 %14 %11 +%15 = OpCompositeExtract %16 %4 1 +%17 = OpGroupNonUniformAllEqual %2 %12 %15 +%18 = OpLogicalAnd %2 %13 %17 +%19 = OpCompositeExtract %20 %4 2 +%21 = OpGroupNonUniformAllEqual %2 %12 %19 +%22 = OpLogicalAnd %2 %18 %21 +OpNoLine +OpReturnValue %22 +OpFunctionEnd From 897169350540e24096f8a37780c461228bb01882 Mon Sep 17 00:00:00 2001 From: firestar99 Date: Thu, 16 Oct 2025 16:11:49 +0200 Subject: [PATCH 4/6] ScalarComposite: derive enums and improve enum docs --- .../macros/src/scalar_or_vector_composite.rs | 85 +++++++---- crates/spirv-std/src/scalar.rs | 4 + crates/spirv-std/src/scalar_or_vector.rs | 18 ++- .../arch/subgroup/subgroup_composite_enum.rs | 53 +++++++ .../subgroup/subgroup_composite_enum.stderr | 20 +++ .../subgroup/subgroup_composite_enum_err.rs | 89 ++++++++++++ .../subgroup_composite_enum_err.stderr | 136 ++++++++++++++++++ 7 files changed, 377 insertions(+), 28 deletions(-) create mode 100644 tests/compiletests/ui/arch/subgroup/subgroup_composite_enum.rs create mode 100644 tests/compiletests/ui/arch/subgroup/subgroup_composite_enum.stderr create mode 100644 tests/compiletests/ui/arch/subgroup/subgroup_composite_enum_err.rs create mode 100644 tests/compiletests/ui/arch/subgroup/subgroup_composite_enum_err.stderr diff --git a/crates/spirv-std/macros/src/scalar_or_vector_composite.rs b/crates/spirv-std/macros/src/scalar_or_vector_composite.rs index 8b83ccad7c..f8b47c68fc 100644 --- a/crates/spirv-std/macros/src/scalar_or_vector_composite.rs +++ b/crates/spirv-std/macros/src/scalar_or_vector_composite.rs @@ -1,15 +1,26 @@ use proc_macro2::TokenStream; use quote::{ToTokens, quote}; use syn::punctuated::Punctuated; -use syn::{Fields, FieldsNamed, FieldsUnnamed, GenericParam, Token}; +use syn::{ + Data, DataStruct, DataUnion, DeriveInput, Fields, FieldsNamed, FieldsUnnamed, GenericParam, + Token, +}; pub fn derive(item: TokenStream) -> syn::Result { // Whenever we'll properly resolve the crate symbol, replace this. let spirv_std = quote!(spirv_std); // Defer all validation to our codegen backend. Rather than erroring here, emit garbage. - let item = syn::parse2::(item)?; - let struct_ident = &item.ident; + let item = syn::parse2::(item)?; + let content = match &item.data { + Data::Enum(_) => derive_enum(&spirv_std, &item), + Data::Struct(data) => derive_struct(&spirv_std, data), + Data::Union(DataUnion { union_token, .. }) => { + Err(syn::Error::new_spanned(union_token, "Union not supported")) + } + }?; + + let ident = &item.ident; let gens = &item.generics.params; let gen_refs = &item .generics @@ -23,29 +34,8 @@ pub fn derive(item: TokenStream) -> syn::Result { .collect::>(); let where_clause = &item.generics.where_clause; - let content = - match item.fields { - Fields::Named(FieldsNamed { named, .. }) => { - let content = named.iter().map(|f| { - let ident = &f.ident; - quote!(#ident: #spirv_std::ScalarOrVectorComposite::transform(self.#ident, f)) - }).collect::>(); - quote!(Self { #content }) - } - Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => { - let content = (0..unnamed.len()) - .map(|i| { - let i = syn::Index::from(i); - quote!(#spirv_std::ScalarOrVectorComposite::transform(self.#i, f)) - }) - .collect::>(); - quote!(Self(#content)) - } - Fields::Unit => quote!(Self), - }; - Ok(quote! { - impl<#gens> #spirv_std::ScalarOrVectorComposite for #struct_ident<#gen_refs> #where_clause { + impl<#gens> #spirv_std::ScalarOrVectorComposite for #ident<#gen_refs> #where_clause { #[inline] fn transform(self, f: &mut F) -> Self { #content @@ -53,3 +43,48 @@ pub fn derive(item: TokenStream) -> syn::Result { } }) } + +pub fn derive_struct(spirv_std: &TokenStream, data: &DataStruct) -> syn::Result { + Ok(match &data.fields { + Fields::Named(FieldsNamed { named, .. }) => { + let content = named + .iter() + .map(|f| { + let ident = &f.ident; + quote!(#ident: #spirv_std::ScalarOrVectorComposite::transform(self.#ident, f)) + }) + .collect::>(); + quote!(Self { #content }) + } + Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => { + let content = (0..unnamed.len()) + .map(|i| { + let i = syn::Index::from(i); + quote!(#spirv_std::ScalarOrVectorComposite::transform(self.#i, f)) + }) + .collect::>(); + quote!(Self(#content)) + } + Fields::Unit => quote!(Self), + }) +} + +pub fn derive_enum(spirv_std: &TokenStream, item: &DeriveInput) -> syn::Result { + let mut attributes = item.attrs.iter().filter(|a| a.path().is_ident("repr")); + let repr = match (attributes.next(), attributes.next()) { + (None, _) => Err(syn::Error::new_spanned( + item, + "Missing #[repr(...)] attribute", + )), + (Some(repr), None) => Ok(repr), + (Some(_), Some(_)) => Err(syn::Error::new_spanned( + item, + "Multiple #[repr(...)] attributes found", + )), + }?; + let prim = &repr.meta.require_list()?.tokens; + Ok(quote! { + #spirv_std::assert_is_integer::<#prim>(); + >::from(#spirv_std::ScalarOrVectorComposite::transform(>::into(self), f)) + }) +} diff --git a/crates/spirv-std/src/scalar.rs b/crates/spirv-std/src/scalar.rs index 363035dd49..52d4b365f2 100644 --- a/crates/spirv-std/src/scalar.rs +++ b/crates/spirv-std/src/scalar.rs @@ -117,3 +117,7 @@ impl_scalar! { impl Float for f64; impl Scalar for bool; } + +/// used by `ScalarOrVector` derive when working with enums +#[inline] +pub fn assert_is_integer() {} diff --git a/crates/spirv-std/src/scalar_or_vector.rs b/crates/spirv-std/src/scalar_or_vector.rs index a8497924e8..3e1b659fcf 100644 --- a/crates/spirv-std/src/scalar_or_vector.rs +++ b/crates/spirv-std/src/scalar_or_vector.rs @@ -31,11 +31,23 @@ pub unsafe trait ScalarOrVector: ScalarOrVectorComposite + Default { /// data to other threads. /// /// To derive `#[derive(VectorOrScalarComposite)]` on a struct, all members must also implement -/// `VectorOrScalarComposite`. To derive it on an enum, the enum must have `#[repr(N)]` where `N` is an [`Integer`]. -/// Additionally, you must derive `num_enum::FromPrimitive` and `num_enum::ToPrimitive`, which requires the enum to be -/// either exhaustive, implement [`Default`] or a variant of the enum to have the `#[num_enum(default)]` attribute. +/// `VectorOrScalarComposite`. +/// +/// To derive it on an enum, the enum must implement `From` and `Into` where `N` is defined by the `#[repr(N)]` +/// attribute on the enum and is an [`Integer`], like `u32`. +/// Note that some [safe subgroup operations] may return an "undefined result", so your `From` must gracefully handle +/// arbitrary bit patterns being passed to it. While panicking is legal, it is discouraged as it may result in +/// unexpected control flow. +/// To implement these conversion traits, we recommend [`FromPrimitive`] and [`IntoPrimitive`] from the [`num_enum`] +/// crate. [`FromPrimitive`] requires that either the enum is exhaustive, or you provide it with a variant to default +/// to, by either implementing [`Default`] or marking a variant with `#[num_enum(default)]`. Note to disable default +/// features on the [`num_enum`] crate, or it won't compile on SPIR-V. /// /// [`Integer`]: crate::Integer +/// [subgroup operations]: crate::arch::subgroup_shuffle +/// [`FromPrimitive`]: https://docs.rs/num_enum/latest/num_enum/derive.FromPrimitive.html +/// [`IntoPrimitive`]: https://docs.rs/num_enum/latest/num_enum/derive.IntoPrimitive.html +/// [`num_enum`]: https://crates.io/crates/num_enum pub trait ScalarOrVectorComposite: Copy + Send + Sync + 'static { /// Transform the individual [`Scalar`] and [`Vector`] values of this type to a different value. /// diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum.rs b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum.rs new file mode 100644 index 0000000000..6daa901929 --- /dev/null +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum.rs @@ -0,0 +1,53 @@ +// build-pass +// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformShuffle,+ext:SPV_KHR_vulkan_memory_model +// compile-flags: -C llvm-args=--disassemble-fn=subgroup_composite_enum::disassembly +// normalize-stderr-test "OpLine .*\n" -> "" + +use glam::*; +use spirv_std::ScalarOrVectorComposite; +use spirv_std::arch::*; +use spirv_std::spirv; + +#[repr(u32)] +#[derive(Copy, Clone, Default, ScalarOrVectorComposite)] +pub enum MyEnum { + #[default] + A, + B, + C, +} + +impl From for MyEnum { + #[inline] + fn from(value: u32) -> Self { + match value { + 0 => Self::A, + 1 => Self::B, + 2 => Self::C, + _ => Self::default(), + } + } +} + +impl From for u32 { + #[inline] + fn from(value: MyEnum) -> Self { + value as u32 + } +} + +/// this should be a single `subgroup_shuffle` instruction, with all calls inlined +fn disassembly(my_struct: MyEnum, id: u32) -> MyEnum { + subgroup_shuffle(my_struct, id) +} + +#[spirv(compute(threads(32)))] +pub fn main( + #[spirv(local_invocation_index)] inv_id: UVec3, + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] output: &mut MyEnum, +) { + unsafe { + let my_enum = MyEnum::from(inv_id.x % 3); + *output = disassembly(my_enum, 5); + } +} diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum.stderr b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum.stderr new file mode 100644 index 0000000000..091689e0fc --- /dev/null +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum.stderr @@ -0,0 +1,20 @@ +%1 = OpFunction %2 None %3 +%4 = OpFunctionParameter %2 +%5 = OpFunctionParameter %2 +%6 = OpLabel +%8 = OpGroupNonUniformShuffle %2 %9 %4 %5 +OpNoLine +OpSelectionMerge %10 None +OpSwitch %8 %11 0 %12 1 %13 2 %14 +%11 = OpLabel +OpBranch %10 +%12 = OpLabel +OpBranch %10 +%13 = OpLabel +OpBranch %10 +%14 = OpLabel +OpBranch %10 +%10 = OpLabel +%15 = OpPhi %2 %16 %11 %16 %12 %17 %13 %18 %14 +OpReturnValue %15 +OpFunctionEnd diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum_err.rs b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum_err.rs new file mode 100644 index 0000000000..2706891ec2 --- /dev/null +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum_err.rs @@ -0,0 +1,89 @@ +// build-fail +// normalize-stderr-test "\S*/crates/spirv-std/src/" -> "$$SPIRV_STD_SRC/" +// normalize-stderr-test "\.rs:\d+:\d+" -> ".rs:" +// normalize-stderr-test "(\n)\d* *([ -])([\|\+\-\=])" -> "$1 $2$3" + +use glam::*; +use spirv_std::ScalarOrVectorComposite; +use spirv_std::arch::*; +use spirv_std::spirv; + +macro_rules! enum_repr_from { + ($ident:ident, $repr:ty) => { + impl From<$repr> for $ident { + #[inline] + fn from(value: $repr) -> Self { + match value { + 0 => Self::A, + 1 => Self::B, + 2 => Self::C, + _ => Self::default(), + } + } + } + + impl From<$ident> for $repr { + #[inline] + fn from(value: $ident) -> Self { + value as $repr + } + } + }; +} + +#[derive(Copy, Clone, Default, ScalarOrVectorComposite)] +pub enum NoRepr { + #[default] + A, + B, + C, +} + +#[repr(u32)] +#[repr(u16)] +#[derive(Copy, Clone, Default, ScalarOrVectorComposite)] +pub enum TwoRepr { + #[default] + A, + B, + C, +} + +#[repr(C)] +#[derive(Copy, Clone, Default, ScalarOrVectorComposite)] +pub enum CRepr { + #[default] + A, + B, + C, +} + +#[repr(i32)] +#[derive(Copy, Clone, Default, ScalarOrVectorComposite)] +pub enum NoFrom { + #[default] + A, + B, + C, +} + +#[repr(i32)] +#[derive(Copy, Clone, Default, ScalarOrVectorComposite)] +pub enum WrongFrom { + #[default] + A, + B, + C, +} + +enum_repr_from!(WrongFrom, u32); + +#[repr(i32)] +#[derive(Copy, Clone, ScalarOrVectorComposite)] +pub enum NoDefault { + A, + B, + C, +} + +enum_repr_from!(NoDefault, i32); diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum_err.stderr b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum_err.stderr new file mode 100644 index 0000000000..8665751b2a --- /dev/null +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum_err.stderr @@ -0,0 +1,136 @@ +error: Missing #[repr(...)] attribute + --> $DIR/subgroup_composite_enum_err.rs: + | +LL | / pub enum NoRepr { +LL | | #[default] +LL | | A, +LL | | B, +LL | | C, +LL | | } + | |_^ + +error: Multiple #[repr(...)] attributes found + --> $DIR/subgroup_composite_enum_err.rs: + | +LL | / #[repr(u32)] +LL | | #[repr(u16)] +LL | | #[derive(Copy, Clone, Default, ScalarOrVectorComposite)] +LL | | pub enum TwoRepr { +... | +LL | | C, +LL | | } + | |_^ + +error[E0412]: cannot find type `C` in this scope + --> $DIR/subgroup_composite_enum_err.rs: + | +LL | #[repr(C)] + | ^ +LL | #[derive(Copy, Clone, Default, ScalarOrVectorComposite)] + | ----------------------- similarly named type parameter `F` defined here + | +help: there is an enum variant `crate::CRepr::C` and 6 others; try using the variant's enum + | +LL - #[repr(C)] +LL + #[repr(crate::CRepr)] + | +LL - #[repr(C)] +LL + #[repr(crate::NoDefault)] + | +LL - #[repr(C)] +LL + #[repr(crate::NoFrom)] + | +LL - #[repr(C)] +LL + #[repr(crate::NoRepr)] + | + and 2 other candidates +help: a type parameter with a similar name exists + | +LL - #[repr(C)] +LL + #[repr(F)] + | + +error[E0566]: conflicting representation hints + --> $DIR/subgroup_composite_enum_err.rs: + | +LL | #[repr(u32)] + | ^^^ +LL | #[repr(u16)] + | ^^^ + | + = warning: this was previously accepted by the compiler but is being phased out; it will become a hard error in a future release! + = note: for more information, see issue #68585 + = note: `#[deny(conflicting_repr_hints)]` on by default + +error[E0277]: the trait bound `NoFrom: From` is not satisfied + --> $DIR/subgroup_composite_enum_err.rs: + | +LL | #[derive(Copy, Clone, Default, ScalarOrVectorComposite)] + | ^^^^^^^^^^^^^^^^^^^^^^^ the trait `From` is not implemented for `NoFrom` + | + = note: this error originates in the derive macro `ScalarOrVectorComposite` (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0277]: the trait bound `i32: From` is not satisfied + --> $DIR/subgroup_composite_enum_err.rs: + | +LL | #[derive(Copy, Clone, Default, ScalarOrVectorComposite)] + | ^^^^^^^^^^^^^^^^^^^^^^^ the trait `From` is not implemented for `i32` + | + = help: the following other types implement trait `From`: + `i32` implements `From` + `i32` implements `From` + `i32` implements `From` + `i32` implements `From` + `i32` implements `From` + `i32` implements `From` + = note: required for `NoFrom` to implement `Into` + = note: this error originates in the derive macro `ScalarOrVectorComposite` (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0277]: the trait bound `WrongFrom: From` is not satisfied + --> $DIR/subgroup_composite_enum_err.rs: + | +LL | #[derive(Copy, Clone, Default, ScalarOrVectorComposite)] + | ^^^^^^^^^^^^^^^^^^^^^^^ the trait `From` is not implemented for `WrongFrom` + | + = help: the trait `From` is not implemented for `WrongFrom` + but trait `From` is implemented for it + = help: for that trait implementation, expected `u32`, found `i32` + = note: this error originates in the derive macro `ScalarOrVectorComposite` (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0277]: the trait bound `i32: From` is not satisfied + --> $DIR/subgroup_composite_enum_err.rs: + | +LL | #[derive(Copy, Clone, Default, ScalarOrVectorComposite)] + | ^^^^^^^^^^^^^^^^^^^^^^^ the trait `From` is not implemented for `i32` + | + = help: the following other types implement trait `From`: + `i32` implements `From` + `i32` implements `From` + `i32` implements `From` + `i32` implements `From` + `i32` implements `From` + `i32` implements `From` + = note: required for `WrongFrom` to implement `Into` + = note: this error originates in the derive macro `ScalarOrVectorComposite` (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0599]: no variant or associated item named `default` found for enum `NoDefault` in the current scope + --> $DIR/subgroup_composite_enum_err.rs: + | +LL | _ => Self::default(), + | ^^^^^^^ variant or associated item not found in `NoDefault` +... +LL | pub enum NoDefault { + | ------------------ variant or associated item `default` not found for this enum +... +LL | enum_repr_from!(NoDefault, i32); + | ------------------------------- in this macro invocation + | + = help: items from traits can only be used if the trait is implemented and in scope + = note: the following trait defines an item `default`, perhaps you need to implement it: + candidate #1: `Default` + = note: this error originates in the macro `enum_repr_from` (in Nightly builds, run with -Z macro-backtrace for more info) + +error: aborting due to 9 previous errors + +Some errors have detailed explanations: E0277, E0412, E0566, E0599. +For more information about an error, try `rustc --explain E0277`. From dee359c0d0ae91f090b56109609e09f7dcfacecb Mon Sep 17 00:00:00 2001 From: firestar99 Date: Fri, 28 Nov 2025 12:56:36 +0100 Subject: [PATCH 5/6] ScalarComposite: rename `ScalarOrVectorComposite` to `ScalarComposite`, bulk rename --- crates/spirv-std/macros/src/lib.rs | 2 +- .../macros/src/scalar_or_vector_composite.rs | 8 ++--- crates/spirv-std/src/arch/subgroup.rs | 22 +++++++------- crates/spirv-std/src/lib.rs | 2 +- crates/spirv-std/src/scalar.rs | 4 +-- crates/spirv-std/src/scalar_or_vector.rs | 8 ++--- crates/spirv-std/src/vector.rs | 4 +-- .../subgroup_cluster_size_0_fail.stderr | 4 +-- ..._cluster_size_non_power_of_two_fail.stderr | 4 +-- .../ui/arch/subgroup/subgroup_composite.rs | 8 ++--- .../subgroup/subgroup_composite_all_equals.rs | 8 ++--- .../arch/subgroup/subgroup_composite_enum.rs | 4 +-- .../subgroup/subgroup_composite_enum_err.rs | 14 ++++----- .../subgroup_composite_enum_err.stderr | 30 +++++++++---------- .../subgroup/subgroup_composite_shuffle.rs | 8 ++--- 15 files changed, 64 insertions(+), 66 deletions(-) diff --git a/crates/spirv-std/macros/src/lib.rs b/crates/spirv-std/macros/src/lib.rs index 424c9ede5d..05d14af6ed 100644 --- a/crates/spirv-std/macros/src/lib.rs +++ b/crates/spirv-std/macros/src/lib.rs @@ -313,7 +313,7 @@ pub fn gen_sample_param_permutations(_attr: TokenStream, item: TokenStream) -> T sample_param_permutations::gen_sample_param_permutations(item) } -#[proc_macro_derive(ScalarOrVectorComposite)] +#[proc_macro_derive(ScalarComposite)] pub fn derive_scalar_or_vector_composite(item: TokenStream) -> TokenStream { scalar_or_vector_composite::derive(item.into()) .unwrap_or_else(syn::Error::into_compile_error) diff --git a/crates/spirv-std/macros/src/scalar_or_vector_composite.rs b/crates/spirv-std/macros/src/scalar_or_vector_composite.rs index f8b47c68fc..58c8e9967a 100644 --- a/crates/spirv-std/macros/src/scalar_or_vector_composite.rs +++ b/crates/spirv-std/macros/src/scalar_or_vector_composite.rs @@ -35,7 +35,7 @@ pub fn derive(item: TokenStream) -> syn::Result { let where_clause = &item.generics.where_clause; Ok(quote! { - impl<#gens> #spirv_std::ScalarOrVectorComposite for #ident<#gen_refs> #where_clause { + impl<#gens> #spirv_std::ScalarComposite for #ident<#gen_refs> #where_clause { #[inline] fn transform(self, f: &mut F) -> Self { #content @@ -51,7 +51,7 @@ pub fn derive_struct(spirv_std: &TokenStream, data: &DataStruct) -> syn::Result< .iter() .map(|f| { let ident = &f.ident; - quote!(#ident: #spirv_std::ScalarOrVectorComposite::transform(self.#ident, f)) + quote!(#ident: #spirv_std::ScalarComposite::transform(self.#ident, f)) }) .collect::>(); quote!(Self { #content }) @@ -60,7 +60,7 @@ pub fn derive_struct(spirv_std: &TokenStream, data: &DataStruct) -> syn::Result< let content = (0..unnamed.len()) .map(|i| { let i = syn::Index::from(i); - quote!(#spirv_std::ScalarOrVectorComposite::transform(self.#i, f)) + quote!(#spirv_std::ScalarComposite::transform(self.#i, f)) }) .collect::>(); quote!(Self(#content)) @@ -85,6 +85,6 @@ pub fn derive_enum(spirv_std: &TokenStream, item: &DeriveInput) -> syn::Result(); - >::from(#spirv_std::ScalarOrVectorComposite::transform(>::into(self), f)) + >::from(#spirv_std::ScalarComposite::transform(>::into(self), f)) }) } diff --git a/crates/spirv-std/src/arch/subgroup.rs b/crates/spirv-std/src/arch/subgroup.rs index 5985a084c3..6b92b99259 100644 --- a/crates/spirv-std/src/arch/subgroup.rs +++ b/crates/spirv-std/src/arch/subgroup.rs @@ -4,9 +4,7 @@ use crate::ScalarOrVectorTransform; use crate::arch::{asm, barrier}; #[cfg(target_arch = "spirv")] use crate::memory::{Scope, Semantics}; -use crate::{ - Float, Integer, ScalarOrVector, ScalarOrVectorComposite, SignedInteger, UnsignedInteger, -}; +use crate::{Float, Integer, ScalarComposite, ScalarOrVector, SignedInteger, UnsignedInteger}; #[cfg(target_arch = "spirv")] const SUBGROUP: u32 = Scope::Subgroup as u32; @@ -244,7 +242,7 @@ pub fn subgroup_any(predicate: bool) -> bool { #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformAllEqual")] #[inline] -pub fn subgroup_all_equal(value: T) -> bool { +pub fn subgroup_all_equal(value: T) -> bool { struct Transform(bool); impl ScalarOrVectorTransform for Transform { @@ -299,7 +297,7 @@ pub fn subgroup_all_equal(value: T) -> bool { #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformBroadcast")] #[inline] -pub unsafe fn subgroup_broadcast(value: T, id: u32) -> T { +pub unsafe fn subgroup_broadcast(value: T, id: u32) -> T { struct Transform { id: u32, } @@ -390,7 +388,7 @@ pub unsafe fn subgroup_broadcast_const(value: #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformBroadcastFirst")] #[inline] -pub fn subgroup_broadcast_first(value: T) -> T { +pub fn subgroup_broadcast_first(value: T) -> T { struct Transform; impl ScalarOrVectorTransform for Transform { @@ -672,7 +670,7 @@ pub fn subgroup_ballot_find_msb(value: SubgroupMask) -> u32 { #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformShuffle")] #[inline] -pub fn subgroup_shuffle(value: T, id: u32) -> T { +pub fn subgroup_shuffle(value: T, id: u32) -> T { struct Transform { id: u32, } @@ -722,7 +720,7 @@ pub fn subgroup_shuffle(value: T, id: u32) -> T { #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformShuffleXor")] #[inline] -pub fn subgroup_shuffle_xor(value: T, mask: u32) -> T { +pub fn subgroup_shuffle_xor(value: T, mask: u32) -> T { struct Transform { mask: u32, } @@ -772,7 +770,7 @@ pub fn subgroup_shuffle_xor(value: T, mask: u32) -> #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformShuffleUp")] #[inline] -pub fn subgroup_shuffle_up(value: T, delta: u32) -> T { +pub fn subgroup_shuffle_up(value: T, delta: u32) -> T { struct Transform { delta: u32, } @@ -822,7 +820,7 @@ pub fn subgroup_shuffle_up(value: T, delta: u32) -> #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformShuffleDown")] #[inline] -pub fn subgroup_shuffle_down(value: T, delta: u32) -> T { +pub fn subgroup_shuffle_down(value: T, delta: u32) -> T { struct Transform { delta: u32, } @@ -1458,7 +1456,7 @@ Requires Capability `GroupNonUniformArithmetic` and `GroupNonUniformClustered`. #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformQuadBroadcast")] #[inline] -pub fn subgroup_quad_broadcast(value: T, index: u32) -> T { +pub fn subgroup_quad_broadcast(value: T, index: u32) -> T { struct Transform { index: u32, } @@ -1550,7 +1548,7 @@ pub enum QuadDirection { #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformQuadSwap")] #[inline] -pub fn subgroup_quad_swap(value: T) -> T { +pub fn subgroup_quad_swap(value: T) -> T { struct Transform; impl ScalarOrVectorTransform for Transform { diff --git a/crates/spirv-std/src/lib.rs b/crates/spirv-std/src/lib.rs index 2fdcd5610d..6e178fe874 100644 --- a/crates/spirv-std/src/lib.rs +++ b/crates/spirv-std/src/lib.rs @@ -87,7 +87,7 @@ /// Public re-export of the `spirv-std-macros` crate. #[macro_use] pub extern crate spirv_std_macros as macros; -pub use macros::ScalarOrVectorComposite; +pub use macros::ScalarComposite; pub use macros::spirv; pub use macros::{debug_printf, debug_printfln}; diff --git a/crates/spirv-std/src/scalar.rs b/crates/spirv-std/src/scalar.rs index 52d4b365f2..38c791379b 100644 --- a/crates/spirv-std/src/scalar.rs +++ b/crates/spirv-std/src/scalar.rs @@ -1,7 +1,7 @@ //! Traits related to scalars. use crate::sealed::Sealed; -use crate::{ScalarOrVector, ScalarOrVectorComposite, ScalarOrVectorTransform}; +use crate::{ScalarComposite, ScalarOrVector, ScalarOrVectorTransform}; use core::num::NonZeroUsize; /// Abstract trait representing a SPIR-V scalar type, which includes: @@ -61,7 +61,7 @@ pub unsafe trait Float: num_traits::Float + Number { macro_rules! impl_scalar { (impl Scalar for $ty:ty;) => { impl Sealed for $ty {} - impl ScalarOrVectorComposite for $ty { + impl ScalarComposite for $ty { #[inline] fn transform(self, f: &mut F) -> Self { f.transform_scalar(self) diff --git a/crates/spirv-std/src/scalar_or_vector.rs b/crates/spirv-std/src/scalar_or_vector.rs index 3e1b659fcf..bb94f4d3c5 100644 --- a/crates/spirv-std/src/scalar_or_vector.rs +++ b/crates/spirv-std/src/scalar_or_vector.rs @@ -11,7 +11,7 @@ pub(crate) mod sealed { /// /// # Safety /// Your type must also implement [`Scalar`] or [`Vector`], see their safety sections as well. -pub unsafe trait ScalarOrVector: ScalarOrVectorComposite + Default { +pub unsafe trait ScalarOrVector: ScalarComposite + Default { /// Either the scalar component type of the vector or the scalar itself. type Scalar: Scalar; @@ -48,14 +48,14 @@ pub unsafe trait ScalarOrVector: ScalarOrVectorComposite + Default { /// [`FromPrimitive`]: https://docs.rs/num_enum/latest/num_enum/derive.FromPrimitive.html /// [`IntoPrimitive`]: https://docs.rs/num_enum/latest/num_enum/derive.IntoPrimitive.html /// [`num_enum`]: https://crates.io/crates/num_enum -pub trait ScalarOrVectorComposite: Copy + Send + Sync + 'static { +pub trait ScalarComposite: Copy + Send + Sync + 'static { /// Transform the individual [`Scalar`] and [`Vector`] values of this type to a different value. /// /// See [`Self`] for more detail. fn transform(self, f: &mut F) -> Self; } -/// A transform operation for [`ScalarOrVectorComposite::transform`] +/// A transform operation for [`ScalarComposite::transform`] pub trait ScalarOrVectorTransform { /// transform a [`ScalarOrVector`] fn transform(&mut self, value: T) -> T; @@ -74,7 +74,7 @@ pub trait ScalarOrVectorTransform { } /// `Default` is unfortunately necessary until rust-gpu improves -impl ScalarOrVectorComposite for [T; N] { +impl ScalarComposite for [T; N] { #[inline] fn transform(self, f: &mut F) -> Self { let mut out = [T::default(); N]; diff --git a/crates/spirv-std/src/vector.rs b/crates/spirv-std/src/vector.rs index c5464adfc2..e0b6a86bb0 100644 --- a/crates/spirv-std/src/vector.rs +++ b/crates/spirv-std/src/vector.rs @@ -1,7 +1,7 @@ //! Traits related to vectors. use crate::sealed::Sealed; -use crate::{Scalar, ScalarOrVector, ScalarOrVectorComposite, ScalarOrVectorTransform}; +use crate::{Scalar, ScalarComposite, ScalarOrVector, ScalarOrVectorTransform}; use core::num::NonZeroUsize; use glam::{Vec3Swizzles, Vec4Swizzles}; @@ -57,7 +57,7 @@ macro_rules! impl_vector { ($($ty:ty: [$scalar:ty; $n:literal];)+) => { $( impl Sealed for $ty {} - impl ScalarOrVectorComposite for $ty { + impl ScalarComposite for $ty { #[inline] fn transform(self, f: &mut F) -> Self { f.transform_vector(self) diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_0_fail.stderr b/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_0_fail.stderr index bc6d3a980f..d5f06bdb4b 100644 --- a/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_0_fail.stderr +++ b/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_0_fail.stderr @@ -1,5 +1,5 @@ error[E0080]: evaluation panicked: `ClusterSize` must be at least 1 - --> $SPIRV_STD_SRC/arch/subgroup.rs:939:1 + --> $SPIRV_STD_SRC/arch/subgroup.rs:937:1 | LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r" LL | | An integer add group operation of all `value` operands contributed by active invocations in the group. @@ -13,7 +13,7 @@ LL | | "); = note: this error originates in the macro `$crate::panic::panic_2021` which comes from the expansion of the macro `macro_subgroup_op_clustered` (in Nightly builds, run with -Z macro-backtrace for more info) note: erroneous constant encountered - --> $SPIRV_STD_SRC/arch/subgroup.rs:939:1 + --> $SPIRV_STD_SRC/arch/subgroup.rs:937:1 | LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r" LL | | An integer add group operation of all `value` operands contributed by active invocations in the group. diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_non_power_of_two_fail.stderr b/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_non_power_of_two_fail.stderr index e254fb228b..db300c3e12 100644 --- a/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_non_power_of_two_fail.stderr +++ b/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_non_power_of_two_fail.stderr @@ -1,5 +1,5 @@ error[E0080]: evaluation panicked: `ClusterSize` must be a power of 2 - --> $SPIRV_STD_SRC/arch/subgroup.rs:939:1 + --> $SPIRV_STD_SRC/arch/subgroup.rs:937:1 | LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r" LL | | An integer add group operation of all `value` operands contributed by active invocations in the group. @@ -13,7 +13,7 @@ LL | | "); = note: this error originates in the macro `$crate::panic::panic_2021` which comes from the expansion of the macro `macro_subgroup_op_clustered` (in Nightly builds, run with -Z macro-backtrace for more info) note: erroneous constant encountered - --> $SPIRV_STD_SRC/arch/subgroup.rs:939:1 + --> $SPIRV_STD_SRC/arch/subgroup.rs:937:1 | LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r" LL | | An integer add group operation of all `value` operands contributed by active invocations in the group. diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite.rs b/tests/compiletests/ui/arch/subgroup/subgroup_composite.rs index ca4829f521..5db5ba963e 100644 --- a/tests/compiletests/ui/arch/subgroup/subgroup_composite.rs +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite.rs @@ -10,11 +10,11 @@ // ignore-spv1.4 use glam::*; -use spirv_std::ScalarOrVectorComposite; +use spirv_std::ScalarComposite; use spirv_std::arch::*; use spirv_std::spirv; -#[derive(Copy, Clone, ScalarOrVectorComposite)] +#[derive(Copy, Clone, ScalarComposite)] pub struct MyStruct { a: f32, b: UVec3, @@ -22,10 +22,10 @@ pub struct MyStruct { d: Zst, } -#[derive(Copy, Clone, ScalarOrVectorComposite)] +#[derive(Copy, Clone, ScalarComposite)] pub struct Nested(i32); -#[derive(Copy, Clone, ScalarOrVectorComposite)] +#[derive(Copy, Clone, ScalarComposite)] pub struct Zst; #[spirv(compute(threads(32)))] diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_all_equals.rs b/tests/compiletests/ui/arch/subgroup/subgroup_composite_all_equals.rs index 2c1c12f9aa..953a4429e8 100644 --- a/tests/compiletests/ui/arch/subgroup/subgroup_composite_all_equals.rs +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_all_equals.rs @@ -4,11 +4,11 @@ // normalize-stderr-test "OpLine .*\n" -> "" use glam::*; -use spirv_std::ScalarOrVectorComposite; +use spirv_std::ScalarComposite; use spirv_std::arch::*; use spirv_std::spirv; -#[derive(Copy, Clone, ScalarOrVectorComposite)] +#[derive(Copy, Clone, ScalarComposite)] pub struct MyStruct { a: f32, b: UVec3, @@ -16,10 +16,10 @@ pub struct MyStruct { d: Zst, } -#[derive(Copy, Clone, ScalarOrVectorComposite)] +#[derive(Copy, Clone, ScalarComposite)] pub struct Nested(i32); -#[derive(Copy, Clone, ScalarOrVectorComposite)] +#[derive(Copy, Clone, ScalarComposite)] pub struct Zst; /// this should be 3 `subgroup_all_equal` instructions, with all calls inlined diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum.rs b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum.rs index 6daa901929..ada500fb06 100644 --- a/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum.rs +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum.rs @@ -4,12 +4,12 @@ // normalize-stderr-test "OpLine .*\n" -> "" use glam::*; -use spirv_std::ScalarOrVectorComposite; +use spirv_std::ScalarComposite; use spirv_std::arch::*; use spirv_std::spirv; #[repr(u32)] -#[derive(Copy, Clone, Default, ScalarOrVectorComposite)] +#[derive(Copy, Clone, Default, ScalarComposite)] pub enum MyEnum { #[default] A, diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum_err.rs b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum_err.rs index 2706891ec2..0579efeeaf 100644 --- a/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum_err.rs +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum_err.rs @@ -4,7 +4,7 @@ // normalize-stderr-test "(\n)\d* *([ -])([\|\+\-\=])" -> "$1 $2$3" use glam::*; -use spirv_std::ScalarOrVectorComposite; +use spirv_std::ScalarComposite; use spirv_std::arch::*; use spirv_std::spirv; @@ -31,7 +31,7 @@ macro_rules! enum_repr_from { }; } -#[derive(Copy, Clone, Default, ScalarOrVectorComposite)] +#[derive(Copy, Clone, Default, ScalarComposite)] pub enum NoRepr { #[default] A, @@ -41,7 +41,7 @@ pub enum NoRepr { #[repr(u32)] #[repr(u16)] -#[derive(Copy, Clone, Default, ScalarOrVectorComposite)] +#[derive(Copy, Clone, Default, ScalarComposite)] pub enum TwoRepr { #[default] A, @@ -50,7 +50,7 @@ pub enum TwoRepr { } #[repr(C)] -#[derive(Copy, Clone, Default, ScalarOrVectorComposite)] +#[derive(Copy, Clone, Default, ScalarComposite)] pub enum CRepr { #[default] A, @@ -59,7 +59,7 @@ pub enum CRepr { } #[repr(i32)] -#[derive(Copy, Clone, Default, ScalarOrVectorComposite)] +#[derive(Copy, Clone, Default, ScalarComposite)] pub enum NoFrom { #[default] A, @@ -68,7 +68,7 @@ pub enum NoFrom { } #[repr(i32)] -#[derive(Copy, Clone, Default, ScalarOrVectorComposite)] +#[derive(Copy, Clone, Default, ScalarComposite)] pub enum WrongFrom { #[default] A, @@ -79,7 +79,7 @@ pub enum WrongFrom { enum_repr_from!(WrongFrom, u32); #[repr(i32)] -#[derive(Copy, Clone, ScalarOrVectorComposite)] +#[derive(Copy, Clone, ScalarComposite)] pub enum NoDefault { A, B, diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum_err.stderr b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum_err.stderr index 8665751b2a..4d0c7ce481 100644 --- a/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum_err.stderr +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum_err.stderr @@ -14,7 +14,7 @@ error: Multiple #[repr(...)] attributes found | LL | / #[repr(u32)] LL | | #[repr(u16)] -LL | | #[derive(Copy, Clone, Default, ScalarOrVectorComposite)] +LL | | #[derive(Copy, Clone, Default, ScalarComposite)] LL | | pub enum TwoRepr { ... | LL | | C, @@ -26,8 +26,8 @@ error[E0412]: cannot find type `C` in this scope | LL | #[repr(C)] | ^ -LL | #[derive(Copy, Clone, Default, ScalarOrVectorComposite)] - | ----------------------- similarly named type parameter `F` defined here +LL | #[derive(Copy, Clone, Default, ScalarComposite)] + | --------------- similarly named type parameter `F` defined here | help: there is an enum variant `crate::CRepr::C` and 6 others; try using the variant's enum | @@ -65,16 +65,16 @@ LL | #[repr(u16)] error[E0277]: the trait bound `NoFrom: From` is not satisfied --> $DIR/subgroup_composite_enum_err.rs: | -LL | #[derive(Copy, Clone, Default, ScalarOrVectorComposite)] - | ^^^^^^^^^^^^^^^^^^^^^^^ the trait `From` is not implemented for `NoFrom` +LL | #[derive(Copy, Clone, Default, ScalarComposite)] + | ^^^^^^^^^^^^^^^ the trait `From` is not implemented for `NoFrom` | - = note: this error originates in the derive macro `ScalarOrVectorComposite` (in Nightly builds, run with -Z macro-backtrace for more info) + = note: this error originates in the derive macro `ScalarComposite` (in Nightly builds, run with -Z macro-backtrace for more info) error[E0277]: the trait bound `i32: From` is not satisfied --> $DIR/subgroup_composite_enum_err.rs: | -LL | #[derive(Copy, Clone, Default, ScalarOrVectorComposite)] - | ^^^^^^^^^^^^^^^^^^^^^^^ the trait `From` is not implemented for `i32` +LL | #[derive(Copy, Clone, Default, ScalarComposite)] + | ^^^^^^^^^^^^^^^ the trait `From` is not implemented for `i32` | = help: the following other types implement trait `From`: `i32` implements `From` @@ -84,24 +84,24 @@ LL | #[derive(Copy, Clone, Default, ScalarOrVectorComposite)] `i32` implements `From` `i32` implements `From` = note: required for `NoFrom` to implement `Into` - = note: this error originates in the derive macro `ScalarOrVectorComposite` (in Nightly builds, run with -Z macro-backtrace for more info) + = note: this error originates in the derive macro `ScalarComposite` (in Nightly builds, run with -Z macro-backtrace for more info) error[E0277]: the trait bound `WrongFrom: From` is not satisfied --> $DIR/subgroup_composite_enum_err.rs: | -LL | #[derive(Copy, Clone, Default, ScalarOrVectorComposite)] - | ^^^^^^^^^^^^^^^^^^^^^^^ the trait `From` is not implemented for `WrongFrom` +LL | #[derive(Copy, Clone, Default, ScalarComposite)] + | ^^^^^^^^^^^^^^^ the trait `From` is not implemented for `WrongFrom` | = help: the trait `From` is not implemented for `WrongFrom` but trait `From` is implemented for it = help: for that trait implementation, expected `u32`, found `i32` - = note: this error originates in the derive macro `ScalarOrVectorComposite` (in Nightly builds, run with -Z macro-backtrace for more info) + = note: this error originates in the derive macro `ScalarComposite` (in Nightly builds, run with -Z macro-backtrace for more info) error[E0277]: the trait bound `i32: From` is not satisfied --> $DIR/subgroup_composite_enum_err.rs: | -LL | #[derive(Copy, Clone, Default, ScalarOrVectorComposite)] - | ^^^^^^^^^^^^^^^^^^^^^^^ the trait `From` is not implemented for `i32` +LL | #[derive(Copy, Clone, Default, ScalarComposite)] + | ^^^^^^^^^^^^^^^ the trait `From` is not implemented for `i32` | = help: the following other types implement trait `From`: `i32` implements `From` @@ -111,7 +111,7 @@ LL | #[derive(Copy, Clone, Default, ScalarOrVectorComposite)] `i32` implements `From` `i32` implements `From` = note: required for `WrongFrom` to implement `Into` - = note: this error originates in the derive macro `ScalarOrVectorComposite` (in Nightly builds, run with -Z macro-backtrace for more info) + = note: this error originates in the derive macro `ScalarComposite` (in Nightly builds, run with -Z macro-backtrace for more info) error[E0599]: no variant or associated item named `default` found for enum `NoDefault` in the current scope --> $DIR/subgroup_composite_enum_err.rs: diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_shuffle.rs b/tests/compiletests/ui/arch/subgroup/subgroup_composite_shuffle.rs index 1009fb74ca..a40c26175a 100644 --- a/tests/compiletests/ui/arch/subgroup/subgroup_composite_shuffle.rs +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_shuffle.rs @@ -4,11 +4,11 @@ // normalize-stderr-test "OpLine .*\n" -> "" use glam::*; -use spirv_std::ScalarOrVectorComposite; +use spirv_std::ScalarComposite; use spirv_std::arch::*; use spirv_std::spirv; -#[derive(Copy, Clone, ScalarOrVectorComposite)] +#[derive(Copy, Clone, ScalarComposite)] pub struct MyStruct { a: f32, b: UVec3, @@ -16,10 +16,10 @@ pub struct MyStruct { d: Zst, } -#[derive(Copy, Clone, ScalarOrVectorComposite)] +#[derive(Copy, Clone, ScalarComposite)] pub struct Nested(i32); -#[derive(Copy, Clone, ScalarOrVectorComposite)] +#[derive(Copy, Clone, ScalarComposite)] pub struct Zst; /// this should be 3 `subgroup_shuffle` instructions, with all calls inlined From 95e557ce0fad154204887a4c7f52addd6d067ef5 Mon Sep 17 00:00:00 2001 From: firestar99 Date: Fri, 28 Nov 2025 13:11:10 +0100 Subject: [PATCH 6/6] ScalarComposite: adjust docs to the new name --- crates/spirv-std/src/scalar_or_vector.rs | 25 ++++++++++++------------ crates/spirv-std/src/vector.rs | 2 +- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/crates/spirv-std/src/scalar_or_vector.rs b/crates/spirv-std/src/scalar_or_vector.rs index bb94f4d3c5..43ad56d455 100644 --- a/crates/spirv-std/src/scalar_or_vector.rs +++ b/crates/spirv-std/src/scalar_or_vector.rs @@ -19,32 +19,31 @@ pub unsafe trait ScalarOrVector: ScalarComposite + Default { const N: NonZeroUsize; } -/// A `VectorOrScalarComposite` is a type that is either +/// A `ScalarComposite` is a type that is either /// * a [`Scalar`] -/// * a [`Vector`] -/// * an array of `VectorOrScalarComposite` -/// * a struct where all members are `VectorOrScalarComposite` +/// * a [`Vector`] (since vectors are made from scalars) +/// * an array of `ScalarComposite` +/// * a struct where all members are `ScalarComposite` /// * an enum with a `repr` that is a [`Scalar`] /// /// By calling [`Self::transform`] you can visit all the individual [`Scalar`] and [`Vector`] values this composite is /// build out of and transform them into some other value. This is particularly useful for subgroup intrinsics sending /// data to other threads. /// -/// To derive `#[derive(VectorOrScalarComposite)]` on a struct, all members must also implement -/// `VectorOrScalarComposite`. +/// To derive `ScalarComposite` on a struct, all members must also implement `ScalarComposite`. /// -/// To derive it on an enum, the enum must implement `From` and `Into` where `N` is defined by the `#[repr(N)]` -/// attribute on the enum and is an [`Integer`], like `u32`. +/// To derive `ScalarComposite` on an enum, the enum must implement `From` and `Into` where `N` is defined by the +/// `#[repr(N)]` attribute on the enum and must be an [`Integer`], like `u32`. /// Note that some [safe subgroup operations] may return an "undefined result", so your `From` must gracefully handle /// arbitrary bit patterns being passed to it. While panicking is legal, it is discouraged as it may result in /// unexpected control flow. /// To implement these conversion traits, we recommend [`FromPrimitive`] and [`IntoPrimitive`] from the [`num_enum`] -/// crate. [`FromPrimitive`] requires that either the enum is exhaustive, or you provide it with a variant to default -/// to, by either implementing [`Default`] or marking a variant with `#[num_enum(default)]`. Note to disable default +/// crate. [`FromPrimitive`] requires the enum to either be exhaustive or have a variant to default to, by either +/// implementing [`Default`] or marking a variant with `#[num_enum(default)]`. Note to disable default /// features on the [`num_enum`] crate, or it won't compile on SPIR-V. /// /// [`Integer`]: crate::Integer -/// [subgroup operations]: crate::arch::subgroup_shuffle +/// [safe subgroup operations]: crate::arch::subgroup_shuffle /// [`FromPrimitive`]: https://docs.rs/num_enum/latest/num_enum/derive.FromPrimitive.html /// [`IntoPrimitive`]: https://docs.rs/num_enum/latest/num_enum/derive.IntoPrimitive.html /// [`num_enum`]: https://crates.io/crates/num_enum @@ -60,13 +59,13 @@ pub trait ScalarOrVectorTransform { /// transform a [`ScalarOrVector`] fn transform(&mut self, value: T) -> T; - /// transform a [`Scalar`], defaults to [`self.transform`] + /// transform a [`Scalar`], defaults to [`Self::transform`] #[inline] fn transform_scalar(&mut self, value: T) -> T { self.transform(value) } - /// transform a [`Vector`], defaults to [`self.transform`] + /// transform a [`Vector`], defaults to [`Self::transform`] #[inline] fn transform_vector, S: Scalar, const N: usize>(&mut self, value: V) -> V { self.transform(value) diff --git a/crates/spirv-std/src/vector.rs b/crates/spirv-std/src/vector.rs index e0b6a86bb0..bd8a952ba7 100644 --- a/crates/spirv-std/src/vector.rs +++ b/crates/spirv-std/src/vector.rs @@ -45,7 +45,7 @@ use glam::{Vec3Swizzles, Vec4Swizzles}; /// # Safety /// * Must only be implemented on types that the spirv codegen emits as valid `OpTypeVector`. This includes all structs /// marked with `#[rust_gpu::vector::v1]`, like [`glam`]'s non-SIMD "scalar" vector types. -/// * `VectorOrScalar::DIM == N`, since const equality is behind rustc feature `associated_const_equality` +/// * `ScalarOrVector::DIM == N`, since const equality is behind rustc feature `associated_const_equality` // Note(@firestar99) I would like to have these two generics be associated types instead. Doesn't make much sense for // a vector type to implement this interface multiple times with different Scalar types or N, after all. // While it's possible with `T: Scalar`, it's not with `const N: usize`, since some impl blocks in `image::params` need