Skip to content

Commit d751c82

Browse files
authored
[SPIRV] Allow spirv type as template parameter (microsoft#7626)
SPIR-V intrinsics allow us to create spirv basic type and opaque type in HLSL, but these type are object and not allowed in template parameter. ```fundamental error: object 'Int8Type' is not allowed in builtin template parameters /* OpTypeCooperativeMatrixKHR */ 4456, Int8Type, ^ ``` This doesn't make sense to me, and is not convenience to use. This change wants to allow that use those in template parameter.
1 parent 4fcf67f commit d751c82

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

tools/clang/lib/Sema/SemaHLSL.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5402,6 +5402,15 @@ class HLSLExternalSource : public ExternalSemaSource {
54025402
objectKind = ClassifyRecordType(recordType);
54035403
switch (objectKind) {
54045404
case AR_TOBJ_OBJECT:
5405+
#ifdef ENABLE_SPIRV_CODEGEN
5406+
if (const auto *namespaceDecl = dyn_cast<NamespaceDecl>(
5407+
recordType->getDecl()->getDeclContext());
5408+
namespaceDecl && namespaceDecl->getName().equals("vk") &&
5409+
(recordType->getDecl()->getName().equals("SpirvType") ||
5410+
recordType->getDecl()->getName().equals("SpirvOpaqueType"))) {
5411+
return true;
5412+
}
5413+
#endif
54055414
m_sema->Diag(argLoc, diag::err_hlsl_unsupported_object_context)
54065415
<< type << static_cast<unsigned>(TypeDiagContext::TypeParameter);
54075416
return false;
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: %dxc -T cs_6_8 -HV 2021 -O0 -spirv -fspv-target-env=universal1.5 %s | FileCheck %s
2+
3+
// CHECK: [[Int8Type:%.*]] = OpTypeInt 8 0
4+
using Int8Type = vk::SpirvType</* OpTypeInt */ 21, 8, 8,
5+
vk::Literal<vk::integral_constant<uint32_t, 8> >,
6+
vk::Literal<vk::integral_constant<bool, 0> > >;
7+
8+
// CHECK: [[MatrixType:%.*]] = OpTypeCooperativeMatrixKHR [[Int8Type]] %uint_3 %uint_16 %uint_16 %uint_0
9+
using I8MatA = vk::SpirvOpaqueType<
10+
/* OpTypeCooperativeMatrixKHR */ 4456, Int8Type,
11+
vk::integral_constant<uint, /* ScopeSubgroup */ 3>,
12+
vk::integral_constant<uint, 16>, vk::integral_constant<uint, 16>,
13+
vk::integral_constant<uint, /* Use */ 0> >;
14+
15+
template <typename ResultType, typename PointerType>
16+
[[vk::ext_instruction(/* OpCooperativeMatrixLoadKHR */ 4457)]] ResultType
17+
__builtin_spv_CooperativeMatrixLoadKHR([[vk::ext_reference]] PointerType pointer,
18+
uint32_t memory_layout, uint32_t stride, [[vk::ext_literal]] uint32_t memory_operand);
19+
20+
StructuredBuffer<uint32_t> buffer : register(t0, space0);
21+
22+
[numthreads(32, 1, 1)] void main() {
23+
[[vk::ext_extension("SPV_KHR_cooperative_matrix")]]
24+
[[vk::ext_capability(/* CooperativeMatrixKHRCapability */ 6022)]]
25+
[[vk::ext_capability(/* VulkanMemoryModel */ 5345)]]
26+
[[vk::ext_capability(/* Int8 */ 39)]]
27+
// CHECK: OpCooperativeMatrixLoadKHR [[MatrixType]] %{{.*}} %uint_0 %uint_32 None
28+
I8MatA matA = __builtin_spv_CooperativeMatrixLoadKHR<I8MatA>(buffer[0], /* rowMajor */ 0, 32, 0);
29+
}

0 commit comments

Comments
 (0)