Skip to content

Commit 283365c

Browse files
authored
[SPIR-V] Fix precision for dot2add (microsoft#7861)
Fixes microsoft#7695 (part of [offload test suite](https://github.com/llvm/offload-test-suite/blob/2e266dae318b2ce38bfd0633bf52fe33ca127bd2/test/Feature/HLSLLib/dot2add.test#L96)) [HLSL spec](https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/hlsl-shader-model-6-4-features-for-direct3d-12#single-precision-floating-point-2-element-dot-product-and-accumulate) indicates that the elements are mutilplied with `half-precision` but the summation results in a `float`. [OpDot](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpDot) requires the `ResultType` to be the same as the vector's `ComponentType`, so this opcode cannot be used. The fix is to untangle `OpDot` -- multiply `half2` vectors and convert them to `float` before summing the elements.
1 parent 784a62e commit 283365c

File tree

2 files changed

+35
-18
lines changed

2 files changed

+35
-18
lines changed

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13165,18 +13165,26 @@ SpirvInstruction *SpirvEmitter::processIntrinsicDP2a(const CallExpr *callExpr) {
1316513165
SpirvInstruction *arg1Instr = doExpr(arg1);
1316613166
SpirvInstruction *arg2Instr = doExpr(arg2);
1316713167

13168-
// Create the dot product of the half2 vectors.
13169-
SpirvInstruction *dotInstr = spvBuilder.createBinaryOp(
13170-
spv::Op::OpDot, componentType, arg0Instr, arg1Instr, loc, range);
13168+
// Multiply the two half2 vectors and convert the result to float2.
13169+
SpirvInstruction *mulInstr = spvBuilder.createBinaryOp(
13170+
spv::Op::OpFMul, vecType, arg0Instr, arg1Instr, loc, range);
13171+
SpirvInstruction *convertInstr = spvBuilder.createUnaryOp(
13172+
spv::Op::OpFConvert,
13173+
astContext.getExtVectorType(astContext.FloatTy, vecSize), mulInstr, loc,
13174+
range);
1317113175

13172-
// Convert dot product (half type) to result type (float).
13173-
QualType resultType = callExpr->getType();
13174-
SpirvInstruction *floatDotInstr = spvBuilder.createUnaryOp(
13175-
spv::Op::OpFConvert, resultType, dotInstr, loc, range);
13176+
// Extract each float element and and sum them up.
13177+
SpirvInstruction *extractedElem0 = spvBuilder.createCompositeExtract(
13178+
astContext.FloatTy, convertInstr, {0}, loc, range);
13179+
SpirvInstruction *extractedElem1 = spvBuilder.createCompositeExtract(
13180+
astContext.FloatTy, convertInstr, {1}, loc, range);
13181+
SpirvInstruction *dotInstr =
13182+
spvBuilder.createBinaryOp(spv::Op::OpFAdd, astContext.FloatTy,
13183+
extractedElem0, extractedElem1, loc, range);
1317613184

1317713185
// Sum the dot product result and accumulator and return.
13178-
return spvBuilder.createBinaryOp(spv::Op::OpFAdd, resultType, floatDotInstr,
13179-
arg2Instr, loc, range);
13186+
return spvBuilder.createBinaryOp(spv::Op::OpFAdd, astContext.FloatTy,
13187+
dotInstr, arg2Instr, loc, range);
1318013188
}
1318113189

1318213190
SpirvInstruction *

tools/clang/test/CodeGenSPIRV/intrinsics.dot2add.hlsl

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@ float main() : SV_Target {
1212
// CHECK: [[input1A:%[0-9]+]] = OpLoad %v2half %input1A
1313
// CHECK: [[input1B:%[0-9]+]] = OpLoad %v2half %input1B
1414
// CHECK: [[acc1:%[0-9]+]] = OpLoad %float %acc1
15-
// CHECK: [[dot1_0:%[0-9]+]] = OpDot %half [[input1A]] [[input1B]]
16-
// CHECK: [[dot1:%[0-9]+]] = OpFConvert %float [[dot1_0]]
17-
// CHECK: [[res1:%[0-9]+]] = OpFAdd %float [[dot1]] [[acc1]]
15+
// CHECK: [[mult1A:%[0-9]+]] = OpFMul %v2half [[input1A]] [[input1B]]
16+
// CHECK: [[convert1A:%[0-9]+]] = OpFConvert %v2float [[mult1A]]
17+
// CHECK: [[extract1A_0:%[0-9]+]] = OpCompositeExtract %float [[convert1A]] 0
18+
// CHECK: [[extract1A_1:%[0-9]+]] = OpCompositeExtract %float [[convert1A]] 1
19+
// CHECK: [[add1A:%[0-9]+]] = OpFAdd %float [[extract1A_0]] [[extract1A_1]]
20+
// CHECK: [[res1:%[0-9]+]] = OpFAdd %float [[add1A]] [[acc1]]
1821
res += dot2add(input1A, input1B, acc1);
1922

2023
half4 input2;
@@ -25,9 +28,12 @@ float main() : SV_Target {
2528
// CHECK: [[input2B:%[0-9]+]] = OpVectorShuffle %v2half [[input2_1]] [[input2_1]] 2 3
2629
// CHECK: [[acc2_0:%[0-9]+]] = OpLoad %int %acc2
2730
// CHECK: [[acc2:%[0-9]+]] = OpConvertSToF %float [[acc2_0]]
28-
// CHECK: [[dot2_0:%[0-9]+]] = OpDot %half [[input2A]] [[input2B]]
29-
// CHECK: [[dot2:%[0-9]+]] = OpFConvert %float [[dot2_0]]
30-
// CHECK: [[res2:%[0-9]+]] = OpFAdd %float [[dot2]] [[acc2]]
31+
// CHECK: [[mult2A:%[0-9]+]] = OpFMul %v2half [[input2A]] [[input2B]]
32+
// CHECK: [[convert2A:%[0-9]+]] = OpFConvert %v2float [[mult2A]]
33+
// CHECK: [[extract2A_0:%[0-9]+]] = OpCompositeExtract %float [[convert2A]] 0
34+
// CHECK: [[extract2A_1:%[0-9]+]] = OpCompositeExtract %float [[convert2A]] 1
35+
// CHECK: [[add2A:%[0-9]+]] = OpFAdd %float [[extract2A_0]] [[extract2A_1]]
36+
// CHECK: [[res2:%[0-9]+]] = OpFAdd %float [[add2A]] [[acc2]]
3137
res += dot2add(input2.xy, input2.zw, acc2);
3238

3339
float input3A;
@@ -44,9 +50,12 @@ float main() : SV_Target {
4450
// CHECK: [[input3B:%[0-9]+]] = OpConvertSToF %v2half [[input3B_5]]
4551
// CHECK: [[acc3_1:%[0-9]+]] = OpLoad %half %acc3
4652
// CHECK: [[acc3:%[0-9]+]] = OpFConvert %float [[acc3_1]]
47-
// CHECK: [[dot3_1:%[0-9]+]] = OpDot %half [[input3A]] [[input3B]]
48-
// CHECK: [[dot3:%[0-9]+]] = OpFConvert %float [[dot3_1]]
49-
// CHECK: [[res3:%[0-9]+]] = OpFAdd %float [[dot3]] [[acc3]]
53+
// CHECK: [[mult3A:%[0-9]+]] = OpFMul %v2half [[input3A]] [[input3B]]
54+
// CHECK: [[convert3A:%[0-9]+]] = OpFConvert %v2float [[mult3A]]
55+
// CHECK: [[extract3A_0:%[0-9]+]] = OpCompositeExtract %float [[convert3A]] 0
56+
// CHECK: [[extract3A_1:%[0-9]+]] = OpCompositeExtract %float [[convert3A]] 1
57+
// CHECK: [[add3A:%[0-9]+]] = OpFAdd %float [[extract3A_0]] [[extract3A_1]]
58+
// CHECK: [[res3:%[0-9]+]] = OpFAdd %float [[add3A]] [[acc3]]
5059
res += dot2add(input3A, input3B, acc3);
5160

5261
return res;

0 commit comments

Comments
 (0)