Skip to content

Commit aacd739

Browse files
committed
location assignment: explicit location assignment with #[spirv(location = 0)]
1 parent cc77930 commit aacd739

File tree

7 files changed

+244
-15
lines changed

7 files changed

+244
-15
lines changed

crates/rustc_codegen_spirv/src/attr.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ pub enum SpirvAttribute {
9494
Builtin(BuiltIn),
9595
DescriptorSet(u32),
9696
Binding(u32),
97+
Location(u32),
9798
Flat,
9899
PerPrimitiveExt,
99100
Invariant,
@@ -130,6 +131,7 @@ pub struct AggregatedSpirvAttributes {
130131
pub builtin: Option<Spanned<BuiltIn>>,
131132
pub descriptor_set: Option<Spanned<u32>>,
132133
pub binding: Option<Spanned<u32>>,
134+
pub location: Option<Spanned<u32>>,
133135
pub flat: Option<Spanned<()>>,
134136
pub invariant: Option<Spanned<()>>,
135137
pub per_primitive_ext: Option<Spanned<()>>,
@@ -216,6 +218,7 @@ impl AggregatedSpirvAttributes {
216218
"#[spirv(descriptor_set)]",
217219
),
218220
Binding(value) => try_insert(&mut self.binding, value, span, "#[spirv(binding)]"),
221+
Location(value) => try_insert(&mut self.location, value, span, "#[spirv(location)]"),
219222
Flat => try_insert(&mut self.flat, (), span, "#[spirv(flat)]"),
220223
Invariant => try_insert(&mut self.invariant, (), span, "#[spirv(invariant)]"),
221224
PerPrimitiveExt => try_insert(
@@ -323,6 +326,7 @@ impl CheckSpirvAttrVisitor<'_> {
323326
| SpirvAttribute::Builtin(_)
324327
| SpirvAttribute::DescriptorSet(_)
325328
| SpirvAttribute::Binding(_)
329+
| SpirvAttribute::Location(_)
326330
| SpirvAttribute::Flat
327331
| SpirvAttribute::Invariant
328332
| SpirvAttribute::PerPrimitiveExt
@@ -602,6 +606,8 @@ fn parse_spirv_attr<'a>(
602606
SpirvAttribute::DescriptorSet(parse_attr_int_value(arg)?)
603607
} else if arg.has_name(sym.binding) {
604608
SpirvAttribute::Binding(parse_attr_int_value(arg)?)
609+
} else if arg.has_name(sym.location) {
610+
SpirvAttribute::Location(parse_attr_int_value(arg)?)
605611
} else if arg.has_name(sym.input_attachment_index) {
606612
SpirvAttribute::InputAttachmentIndex(parse_attr_int_value(arg)?)
607613
} else if arg.has_name(sym.spec_constant) {

crates/rustc_codegen_spirv/src/codegen_cx/entry.rs

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,26 @@ impl<'tcx> CodegenCx<'tcx> {
710710
.name(var_id.or(spec_const_id).unwrap(), ident.to_string());
711711
}
712712

713+
// location assignment
714+
let has_location = matches!(
715+
storage_class,
716+
Ok(StorageClass::Input | StorageClass::Output | StorageClass::UniformConstant)
717+
);
718+
let mut assign_location = |var_id: Result<Word, &str>, explicit: Option<u32>| {
719+
let location = decoration_locations
720+
.entry(storage_class.unwrap())
721+
.or_insert_with(|| 0);
722+
if let Some(explicit) = explicit {
723+
*location = explicit;
724+
}
725+
self.emit_global().decorate(
726+
var_id.unwrap(),
727+
Decoration::Location,
728+
std::iter::once(Operand::LiteralBit32(*location)),
729+
);
730+
*location += value_layout.size.bytes().div_ceil(16) as u32;
731+
};
732+
713733
// Emit `OpDecorate`s based on attributes.
714734
let mut decoration_supersedes_location = false;
715735
if let Some(builtin) = attrs.builtin {
@@ -757,6 +777,35 @@ impl<'tcx> CodegenCx<'tcx> {
757777
);
758778
decoration_supersedes_location = true;
759779
}
780+
if let Some(location) = attrs.location {
781+
if let Err(SpecConstant { .. }) = storage_class {
782+
self.tcx.dcx().span_fatal(
783+
location.span,
784+
"`#[spirv(location = ...)]` cannot apply to `#[spirv(spec_constant)]`",
785+
);
786+
}
787+
if attrs.descriptor_set.is_some() {
788+
self.tcx.dcx().span_fatal(
789+
location.span,
790+
"`#[spirv(location = ...)]` cannot be combined with `#[spirv(descriptor_set = ...)]`",
791+
);
792+
}
793+
if attrs.binding.is_some() {
794+
self.tcx.dcx().span_fatal(
795+
location.span,
796+
"`#[spirv(location = ...)]` cannot be combined with `#[spirv(binding = ...)]`",
797+
);
798+
}
799+
if !has_location {
800+
self.tcx.dcx().span_fatal(
801+
location.span,
802+
"`#[spirv(location = ...)]` can only be used on Inputs (declared as plain values, eg. `Vec4`)\
803+
or Outputs (declared as mut ref, eg. `&mut Vec4`)",
804+
);
805+
}
806+
assign_location(var_id, Some(location.value));
807+
decoration_supersedes_location = true;
808+
}
760809
if let Some(flat) = attrs.flat {
761810
if let Err(SpecConstant { .. }) = storage_class {
762811
self.tcx.dcx().span_fatal(
@@ -867,21 +916,8 @@ impl<'tcx> CodegenCx<'tcx> {
867916
// individually.
868917
// TODO: Is this right for UniformConstant? Do they share locations with
869918
// input/outpus?
870-
let has_location = !decoration_supersedes_location
871-
&& matches!(
872-
storage_class,
873-
Ok(StorageClass::Input | StorageClass::Output | StorageClass::UniformConstant)
874-
);
875-
if has_location {
876-
let location = decoration_locations
877-
.entry(storage_class.unwrap())
878-
.or_insert_with(|| 0);
879-
self.emit_global().decorate(
880-
var_id.unwrap(),
881-
Decoration::Location,
882-
std::iter::once(Operand::LiteralBit32(*location)),
883-
);
884-
*location += value_layout.size.bytes().div_ceil(16) as u32;
919+
if !decoration_supersedes_location && has_location {
920+
assign_location(var_id, None);
885921
}
886922

887923
match storage_class {

crates/rustc_codegen_spirv/src/symbols.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ pub struct Symbols {
2323

2424
pub descriptor_set: Symbol,
2525
pub binding: Symbol,
26+
pub location: Symbol,
2627
pub input_attachment_index: Symbol,
2728

2829
pub spec_constant: Symbol,
@@ -420,6 +421,7 @@ impl Symbols {
420421

421422
descriptor_set: Symbol::intern("descriptor_set"),
422423
binding: Symbol::intern("binding"),
424+
location: Symbol::intern("location"),
423425
input_attachment_index: Symbol::intern("input_attachment_index"),
424426

425427
spec_constant: Symbol::intern("spec_constant"),
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// build-pass
2+
// compile-flags: -C llvm-args=--disassemble
3+
// normalize-stderr-test "OpSource .*\n" -> ""
4+
// normalize-stderr-test "OpLine .*\n" -> ""
5+
// normalize-stderr-test "%\d+ = OpString .*\n" -> ""
6+
// normalize-stderr-test "; .*\n" -> ""
7+
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
8+
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
9+
// ignore-spv1.0
10+
// ignore-spv1.1
11+
// ignore-spv1.2
12+
// ignore-spv1.3
13+
// ignore-vulkan1.0
14+
// ignore-vulkan1.1
15+
16+
use spirv_std::glam::*;
17+
use spirv_std::{Image, spirv};
18+
19+
#[derive(Copy, Clone, Default)]
20+
pub struct LargerThanVec4 {
21+
a: Vec4,
22+
b: Vec2,
23+
}
24+
25+
#[spirv(vertex)]
26+
pub fn main(
27+
#[spirv(location = 4)] out1: &mut LargerThanVec4,
28+
out2: &mut Vec2, // should be 6
29+
#[spirv(location = 0)] out3: &mut Mat4,
30+
// 8 to 11 are unused, that's fine
31+
#[spirv(location = 12)] out4: &mut f32,
32+
) {
33+
*out1 = Default::default();
34+
*out2 = Default::default();
35+
*out3 = Default::default();
36+
*out4 = Default::default();
37+
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
OpCapability Shader
2+
OpMemoryModel Logical Simple
3+
OpEntryPoint Vertex %1 "main" %2 %3 %4 %5
4+
OpName %7 "LargerThanVec4"
5+
OpMemberName %7 0 "a"
6+
OpMemberName %7 1 "b"
7+
OpName %8 "spirv_std::glam::Mat4"
8+
OpMemberName %8 0 "x_axis"
9+
OpMemberName %8 1 "y_axis"
10+
OpMemberName %8 2 "z_axis"
11+
OpMemberName %8 3 "w_axis"
12+
OpName %9 "LargerThanVec4"
13+
OpMemberName %9 0 "a"
14+
OpMemberName %9 1 "b"
15+
OpName %2 "out1"
16+
OpName %3 "out2"
17+
OpName %4 "out3"
18+
OpName %5 "out4"
19+
OpMemberDecorate %9 0 Offset 0
20+
OpMemberDecorate %9 1 Offset 16
21+
OpDecorate %2 Location 4
22+
OpDecorate %3 Location 6
23+
OpDecorate %4 Location 0
24+
OpDecorate %5 Location 12
25+
%10 = OpTypeFloat 32
26+
%11 = OpTypeVector %10 4
27+
%12 = OpTypeVector %10 2
28+
%7 = OpTypeStruct %11 %12
29+
%13 = OpTypePointer Output %7
30+
%14 = OpTypePointer Output %12
31+
%8 = OpTypeStruct %11 %11 %11 %11
32+
%15 = OpTypePointer Output %8
33+
%16 = OpTypePointer Output %10
34+
%17 = OpTypeVoid
35+
%18 = OpTypeFunction %17
36+
%9 = OpTypeStruct %11 %12
37+
%19 = OpConstant %10 0
38+
%20 = OpConstantComposite %11 %19 %19 %19 %19
39+
%21 = OpUndef %9
40+
%2 = OpVariable %13 Output
41+
%3 = OpVariable %14 Output
42+
%22 = OpTypeInt 32 0
43+
%23 = OpConstant %22 0
44+
%24 = OpConstant %22 1
45+
%4 = OpVariable %15 Output
46+
%25 = OpConstant %10 1
47+
%26 = OpConstantComposite %11 %25 %19 %19 %19
48+
%27 = OpConstantComposite %11 %19 %25 %19 %19
49+
%28 = OpConstantComposite %11 %19 %19 %25 %19
50+
%29 = OpConstantComposite %11 %19 %19 %19 %25
51+
%30 = OpConstantComposite %8 %26 %27 %28 %29
52+
%5 = OpVariable %16 Output
53+
%1 = OpFunction %17 None %18
54+
%31 = OpLabel
55+
%32 = OpCompositeInsert %9 %20 %21 0
56+
%33 = OpCompositeInsert %9 %19 %32 1 0
57+
%34 = OpCompositeInsert %9 %19 %33 1 1
58+
%35 = OpCompositeExtract %11 %34 0
59+
%36 = OpCompositeExtract %12 %34 1
60+
%37 = OpCompositeConstruct %7 %35 %36
61+
OpStore %2 %37
62+
%38 = OpInBoundsAccessChain %16 %3 %23
63+
OpStore %38 %19
64+
%39 = OpInBoundsAccessChain %16 %3 %24
65+
OpStore %39 %19
66+
OpStore %4 %30
67+
OpStore %5 %19
68+
OpNoLine
69+
OpReturn
70+
OpFunctionEnd
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// build-fail
2+
// compile-flags: -C llvm-args=--disassemble
3+
// normalize-stderr-test "OpSource .*\n" -> ""
4+
// normalize-stderr-test "OpLine .*\n" -> ""
5+
// normalize-stderr-test "%\d+ = OpString .*\n" -> ""
6+
// normalize-stderr-test "; .*\n" -> ""
7+
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
8+
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
9+
// normalize-stderr-test "= note: module `.*`" -> "= note: module `<normalized>`"
10+
// ignore-spv1.0
11+
// ignore-spv1.1
12+
// ignore-spv1.2
13+
// ignore-spv1.3
14+
// ignore-spv1.4
15+
// ignore-spv1.5
16+
// ignore-spv1.6
17+
// ignore-vulkan1.0
18+
// ignore-vulkan1.1
19+
20+
use spirv_std::glam::*;
21+
use spirv_std::{Image, spirv};
22+
23+
#[spirv(vertex)]
24+
pub fn main(#[spirv(location = 0)] out1: &mut Mat4, #[spirv(location = 1)] out2: &mut Vec2) {
25+
*out1 = Default::default();
26+
*out2 = Default::default();
27+
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
OpCapability Shader
2+
OpMemoryModel Logical Simple
3+
OpEntryPoint Vertex %1 "main" %2 %3
4+
OpName %5 "spirv_std::glam::Mat4"
5+
OpMemberName %5 0 "x_axis"
6+
OpMemberName %5 1 "y_axis"
7+
OpMemberName %5 2 "z_axis"
8+
OpMemberName %5 3 "w_axis"
9+
OpName %2 "out1"
10+
OpName %3 "out2"
11+
OpDecorate %2 Location 0
12+
OpDecorate %3 Location 1
13+
%6 = OpTypeFloat 32
14+
%7 = OpTypeVector %6 4
15+
%5 = OpTypeStruct %7 %7 %7 %7
16+
%8 = OpTypePointer Output %5
17+
%9 = OpTypeVector %6 2
18+
%10 = OpTypePointer Output %9
19+
%11 = OpTypeVoid
20+
%12 = OpTypeFunction %11
21+
%2 = OpVariable %8 Output
22+
%13 = OpConstant %6 1
23+
%14 = OpConstant %6 0
24+
%15 = OpConstantComposite %7 %13 %14 %14 %14
25+
%16 = OpConstantComposite %7 %14 %13 %14 %14
26+
%17 = OpConstantComposite %7 %14 %14 %13 %14
27+
%18 = OpConstantComposite %7 %14 %14 %14 %13
28+
%19 = OpConstantComposite %5 %15 %16 %17 %18
29+
%20 = OpTypePointer Output %6
30+
%3 = OpVariable %10 Output
31+
%21 = OpTypeInt 32 0
32+
%22 = OpConstant %21 0
33+
%23 = OpConstant %21 1
34+
%1 = OpFunction %11 None %12
35+
%24 = OpLabel
36+
OpStore %2 %19
37+
%25 = OpInBoundsAccessChain %20 %3 %22
38+
OpStore %25 %14
39+
%26 = OpInBoundsAccessChain %20 %3 %23
40+
OpStore %26 %14
41+
OpNoLine
42+
OpReturn
43+
OpFunctionEnd
44+
error: error:0:0 - [VUID-StandaloneSpirv-OpEntryPoint-08722] Entry-point has conflicting output location assignment at location 1, component 0
45+
OpEntryPoint Vertex %1 "main" %out1 %out2
46+
|
47+
= note: spirv-val failed
48+
= note: module `<normalized>`
49+
50+
error: aborting due to 1 previous error
51+

0 commit comments

Comments
 (0)