Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions crates/rustc_codegen_spirv/src/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ pub enum SpirvAttribute {
Builtin(BuiltIn),
DescriptorSet(u32),
Binding(u32),
Location(u32),
Flat,
PerPrimitiveExt,
Invariant,
Expand Down Expand Up @@ -130,6 +131,7 @@ pub struct AggregatedSpirvAttributes {
pub builtin: Option<Spanned<BuiltIn>>,
pub descriptor_set: Option<Spanned<u32>>,
pub binding: Option<Spanned<u32>>,
pub location: Option<Spanned<u32>>,
pub flat: Option<Spanned<()>>,
pub invariant: Option<Spanned<()>>,
pub per_primitive_ext: Option<Spanned<()>>,
Expand Down Expand Up @@ -216,6 +218,7 @@ impl AggregatedSpirvAttributes {
"#[spirv(descriptor_set)]",
),
Binding(value) => try_insert(&mut self.binding, value, span, "#[spirv(binding)]"),
Location(value) => try_insert(&mut self.location, value, span, "#[spirv(location)]"),
Flat => try_insert(&mut self.flat, (), span, "#[spirv(flat)]"),
Invariant => try_insert(&mut self.invariant, (), span, "#[spirv(invariant)]"),
PerPrimitiveExt => try_insert(
Expand Down Expand Up @@ -323,6 +326,7 @@ impl CheckSpirvAttrVisitor<'_> {
| SpirvAttribute::Builtin(_)
| SpirvAttribute::DescriptorSet(_)
| SpirvAttribute::Binding(_)
| SpirvAttribute::Location(_)
| SpirvAttribute::Flat
| SpirvAttribute::Invariant
| SpirvAttribute::PerPrimitiveExt
Expand Down Expand Up @@ -602,6 +606,8 @@ fn parse_spirv_attr<'a>(
SpirvAttribute::DescriptorSet(parse_attr_int_value(arg)?)
} else if arg.has_name(sym.binding) {
SpirvAttribute::Binding(parse_attr_int_value(arg)?)
} else if arg.has_name(sym.location) {
SpirvAttribute::Location(parse_attr_int_value(arg)?)
} else if arg.has_name(sym.input_attachment_index) {
SpirvAttribute::InputAttachmentIndex(parse_attr_int_value(arg)?)
} else if arg.has_name(sym.spec_constant) {
Expand Down
81 changes: 66 additions & 15 deletions crates/rustc_codegen_spirv/src/codegen_cx/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,41 @@ impl<'tcx> CodegenCx<'tcx> {
.name(var_id.or(spec_const_id).unwrap(), ident.to_string());
}

// location assignment
// Note(@firestar99): UniformConstant are things like `SampledImage`, `StorageImage`, `Sampler` and
// `Acceleration structure`. Almost always they are assigned a `descriptor_set` and binding, thus never end up
// here being assigned locations. I think this is one of those occasions where spirv allows us to assign
// locations, but the "client API" Vulkan doesn't describe any use-case for them, or at least none I'm aware of.
// A quick scour through the spec revealed that `VK_KHR_dynamic_rendering_local_read` may need this, and while
// we don't support it yet (I assume), I'll just keep it here in case it becomes useful in the future.
let has_location = matches!(
storage_class,
Ok(StorageClass::Input | StorageClass::Output | StorageClass::UniformConstant)
);
let mut assign_location = |var_id: Result<Word, &str>, explicit: Option<u32>| {
let location = decoration_locations
.entry(storage_class.unwrap())
.or_insert_with(|| 0);
if let Some(explicit) = explicit {
*location = explicit;
}
self.emit_global().decorate(
var_id.unwrap(),
Decoration::Location,
std::iter::once(Operand::LiteralBit32(*location)),
);
let spirv_type = self.lookup_type(value_spirv_type);
if let Some(location_size) = spirv_type.location_size(self) {
*location += location_size;
} else {
*location += 1;
self.tcx.dcx().span_err(
hir_param.ty_span,
"Type not supported in Input or Output declarations",
);
}
};

// Emit `OpDecorate`s based on attributes.
let mut decoration_supersedes_location = false;
if let Some(builtin) = attrs.builtin {
Expand Down Expand Up @@ -757,6 +792,35 @@ impl<'tcx> CodegenCx<'tcx> {
);
decoration_supersedes_location = true;
}
if let Some(location) = attrs.location {
if let Err(SpecConstant { .. }) = storage_class {
self.tcx.dcx().span_fatal(
location.span,
"`#[spirv(location = ...)]` cannot apply to `#[spirv(spec_constant)]`",
);
}
if attrs.descriptor_set.is_some() {
self.tcx.dcx().span_fatal(
location.span,
"`#[spirv(location = ...)]` cannot be combined with `#[spirv(descriptor_set = ...)]`",
);
}
if attrs.binding.is_some() {
self.tcx.dcx().span_fatal(
location.span,
"`#[spirv(location = ...)]` cannot be combined with `#[spirv(binding = ...)]`",
);
}
if !has_location {
self.tcx.dcx().span_fatal(
location.span,
"`#[spirv(location = ...)]` can only be used on Inputs (declared as plain values, eg. `Vec4`)\
or Outputs (declared as mut ref, eg. `&mut Vec4`)",
);
}
assign_location(var_id, Some(location.value));
decoration_supersedes_location = true;
}
if let Some(flat) = attrs.flat {
if let Err(SpecConstant { .. }) = storage_class {
self.tcx.dcx().span_fatal(
Expand Down Expand Up @@ -867,21 +931,8 @@ impl<'tcx> CodegenCx<'tcx> {
// individually.
// TODO: Is this right for UniformConstant? Do they share locations with
// input/outpus?
let has_location = !decoration_supersedes_location
&& matches!(
storage_class,
Ok(StorageClass::Input | StorageClass::Output | StorageClass::UniformConstant)
);
if has_location {
let location = decoration_locations
.entry(storage_class.unwrap())
.or_insert_with(|| 0);
self.emit_global().decorate(
var_id.unwrap(),
Decoration::Location,
std::iter::once(Operand::LiteralBit32(*location)),
);
*location += 1;
if !decoration_supersedes_location && has_location {
assign_location(var_id, None);
}

match storage_class {
Expand Down
37 changes: 31 additions & 6 deletions crates/rustc_codegen_spirv/src/spirv_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,34 @@ impl SpirvType<'_> {
id
}

/// Returns how many Input / Output `location`s this type occupies, or None if this type is not allowed to be sent.
///
/// See [Vulkan Spec 16.1.4. Location and Component Assignment](https://registry.khronos.org/vulkan/specs/latest/html/vkspec.html#interfaces-iointerfaces-locations)
#[allow(clippy::match_same_arms)]
pub fn location_size(&self, cx: &CodegenCx<'_>) -> Option<u32> {
let result = match *self {
// bools cannot be in an Input / Output interface
Self::Bool => return None,
Self::Integer(_, _) | Self::Float(_) => 1,
Self::Vector { .. } => 1,
Self::Adt { field_types, .. } => {
let mut locations = 0;
for f in field_types {
locations += cx.lookup_type(*f).location_size(cx)?;
}
locations
}
Self::Matrix { element, count } => cx.lookup_type(element).location_size(cx)? * count,
Self::Array { element, count } => {
let count = cx.builder.lookup_const_scalar(count).unwrap();
let count: u32 = count.try_into().unwrap();
cx.lookup_type(element).location_size(cx)? * count
}
_ => return None,
};
Some(result)
}

pub fn sizeof(&self, cx: &CodegenCx<'_>) -> Option<Size> {
let result = match *self {
// Types that have a dynamic size, or no concept of size at all.
Expand All @@ -287,12 +315,9 @@ impl SpirvType<'_> {
Self::Vector { size, .. } => size,
Self::Matrix { element, count } => cx.lookup_type(element).sizeof(cx)? * count as u64,
Self::Array { element, count } => {
cx.lookup_type(element).sizeof(cx)?
* cx.builder
.lookup_const_scalar(count)
.unwrap()
.try_into()
.unwrap()
let count = cx.builder.lookup_const_scalar(count).unwrap();
let count = count.try_into().unwrap();
cx.lookup_type(element).sizeof(cx)? * count
}
Self::Pointer { .. } => cx.tcx.data_layout.pointer_size,
Self::Image { .. }
Expand Down
2 changes: 2 additions & 0 deletions crates/rustc_codegen_spirv/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub struct Symbols {

pub descriptor_set: Symbol,
pub binding: Symbol,
pub location: Symbol,
pub input_attachment_index: Symbol,

pub spec_constant: Symbol,
Expand Down Expand Up @@ -420,6 +421,7 @@ impl Symbols {

descriptor_set: Symbol::intern("descriptor_set"),
binding: Symbol::intern("binding"),
location: Symbol::intern("location"),
input_attachment_index: Symbol::intern("input_attachment_index"),

spec_constant: Symbol::intern("spec_constant"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,11 @@ LL | #![feature(ptr_internals)]
= note: using it is strongly discouraged
= note: `#[warn(internal_features)]` on by default

error: pointer has non-null integer address
|
note: used from within `allocate_const_scalar::main`
--> $DIR/allocate_const_scalar.rs:16:5
|
LL | *output = POINTER;
| ^^^^^^^^^^^^^^^^^
note: called by `main`
--> $DIR/allocate_const_scalar.rs:15:8
error: Type not supported in Input or Output declarations
--> $DIR/allocate_const_scalar.rs:15:21
|
LL | pub fn main(output: &mut Unique<()>) {
| ^^^^
| ^^^^^^^^^^^^^^^

error: aborting due to 1 previous error; 1 warning emitted

Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ use spirv_std::{Image, spirv};

#[spirv(vertex)]
pub fn main(
#[spirv(uniform)] error: &Image!(2D, type=f32),
#[spirv(uniform_constant)] warning: &Image!(2D, type=f32),
#[spirv(descriptor_set = 0, binding = 0, uniform)] error: &Image!(2D, type=f32),
#[spirv(descriptor_set = 0, binding = 1, uniform_constant)] warning: &Image!(2D, type=f32),
) {
}

// https://github.com/EmbarkStudios/rust-gpu/issues/585
#[spirv(vertex)]
pub fn issue_585(invalid: Image!(2D, type=f32)) {}
pub fn issue_585(#[spirv(descriptor_set = 0, binding = 0)] invalid: Image!(2D, type=f32)) {}
22 changes: 11 additions & 11 deletions tests/compiletests/ui/spirv-attr/bad-deduce-storage-class.stderr
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
error: storage class mismatch
--> $DIR/bad-deduce-storage-class.rs:8:5
|
LL | #[spirv(uniform)] error: &Image!(2D, type=f32),
| ^^^^^^^^-------^^^^^^^^^^---------------------
| | |
| | `UniformConstant` deduced from type
| `Uniform` specified in attribute
LL | #[spirv(descriptor_set = 0, binding = 0, uniform)] error: &Image!(2D, type=f32),
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^-------^^^^^^^^^^---------------------
| | |
| | `UniformConstant` deduced from type
| `Uniform` specified in attribute
|
= help: remove storage class attribute to use `UniformConstant` as storage class

warning: redundant storage class attribute, storage class is deduced from type
--> $DIR/bad-deduce-storage-class.rs:9:13
--> $DIR/bad-deduce-storage-class.rs:9:46
|
LL | #[spirv(uniform_constant)] warning: &Image!(2D, type=f32),
| ^^^^^^^^^^^^^^^^
LL | #[spirv(descriptor_set = 0, binding = 1, uniform_constant)] warning: &Image!(2D, type=f32),
| ^^^^^^^^^^^^^^^^

error: entry parameter type must be by-reference: `&spirv_std::image::Image<f32, 1, 2, 0, 0, 0, 0, 4>`
--> $DIR/bad-deduce-storage-class.rs:15:27
--> $DIR/bad-deduce-storage-class.rs:15:69
|
LL | pub fn issue_585(invalid: Image!(2D, type=f32)) {}
| ^^^^^^^^^^^^^^^^^^^^
LL | pub fn issue_585(#[spirv(descriptor_set = 0, binding = 0)] invalid: Image!(2D, type=f32)) {}
| ^^^^^^^^^^^^^^^^^^^^
|
= note: this error originates in the macro `Image` (in Nightly builds, run with -Z macro-backtrace for more info)

Expand Down
14 changes: 13 additions & 1 deletion tests/compiletests/ui/spirv-attr/bool-inputs-err.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,24 @@ error: entry-point parameter cannot contain `bool`s
LL | input: bool,
| ^^^^

error: Type not supported in Input or Output declarations
--> $DIR/bool-inputs-err.rs:13:12
|
LL | input: bool,
| ^^^^

error: entry-point parameter cannot contain `bool`s
--> $DIR/bool-inputs-err.rs:14:13
|
LL | output: &mut bool,
| ^^^^^^^^^

error: Type not supported in Input or Output declarations
--> $DIR/bool-inputs-err.rs:14:13
|
LL | output: &mut bool,
| ^^^^^^^^^

error: entry-point parameter cannot contain `bool`s
--> $DIR/bool-inputs-err.rs:15:35
|
Expand All @@ -22,5 +34,5 @@ error: entry-point parameter cannot contain `bool`s
LL | #[spirv(uniform)] uniform: &Boolthing,
| ^^^^^^^^^^

error: aborting due to 4 previous errors
error: aborting due to 6 previous errors

31 changes: 31 additions & 0 deletions tests/compiletests/ui/spirv-attr/location_assignment.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// build-pass
// compile-flags: -C llvm-args=--disassemble
// normalize-stderr-test "OpSource .*\n" -> ""
// normalize-stderr-test "OpLine .*\n" -> ""
// normalize-stderr-test "%\d+ = OpString .*\n" -> ""
// normalize-stderr-test "; .*\n" -> ""
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
// ignore-spv1.0
// ignore-spv1.1
// ignore-spv1.2
// ignore-spv1.3
// ignore-vulkan1.0
// ignore-vulkan1.1

use spirv_std::glam::*;
use spirv_std::{Image, spirv};

#[derive(Copy, Clone, Default)]
pub struct LargerThanVec4 {
a: Vec4,
b: Vec2,
}

#[spirv(vertex)]
pub fn main(out1: &mut LargerThanVec4, out2: &mut Vec2, out3: &mut Mat4, out4: &mut f32) {
*out1 = Default::default();
*out2 = Default::default();
*out3 = Default::default();
*out4 = Default::default();
}
Loading