Skip to content

Commit e48d49f

Browse files
authored
[RISCV][llvm] Support VFADD, VFSUB, VFMUL codegen for Zvfbfa (#170612)
Support both fixed-length vectors and scalable vectors. Note: VP version is not gonna be supported for trivial instructions since they're going to be removed soon.
1 parent af4098b commit e48d49f

12 files changed

+2031
-492
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,9 @@ static cl::opt<bool>
9090
// TODO: Support more ops
9191
static const unsigned ZvfbfaVPOps[] = {
9292
ISD::VP_FNEG, ISD::VP_FABS, ISD::VP_FCOPYSIGN, ISD::EXPERIMENTAL_VP_SPLAT};
93-
static const unsigned ZvfbfaOps[] = {ISD::FNEG, ISD::FABS, ISD::FCOPYSIGN,
94-
ISD::SPLAT_VECTOR};
93+
static const unsigned ZvfbfaOps[] = {
94+
ISD::FNEG, ISD::FABS, ISD::FCOPYSIGN, ISD::SPLAT_VECTOR,
95+
ISD::FADD, ISD::FSUB, ISD::FMUL};
9596

9697
RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
9798
const RISCVSubtarget &STI)
@@ -1091,6 +1092,36 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
10911092
ISD::VECREDUCE_FMINIMUM,
10921093
ISD::VECREDUCE_FMAXIMUM};
10931094

1095+
// TODO: Make more of these ops legal.
1096+
static const unsigned ZvfbfaPromoteOps[] = {ISD::FMINNUM,
1097+
ISD::FMAXNUM,
1098+
ISD::FMINIMUMNUM,
1099+
ISD::FMAXIMUMNUM,
1100+
ISD::FDIV,
1101+
ISD::FMA,
1102+
ISD::FSQRT,
1103+
ISD::FCEIL,
1104+
ISD::FTRUNC,
1105+
ISD::FFLOOR,
1106+
ISD::FROUND,
1107+
ISD::FROUNDEVEN,
1108+
ISD::FRINT,
1109+
ISD::FNEARBYINT,
1110+
ISD::IS_FPCLASS,
1111+
ISD::SETCC,
1112+
ISD::FMAXIMUM,
1113+
ISD::FMINIMUM,
1114+
ISD::STRICT_FADD,
1115+
ISD::STRICT_FSUB,
1116+
ISD::STRICT_FMUL,
1117+
ISD::STRICT_FDIV,
1118+
ISD::STRICT_FSQRT,
1119+
ISD::STRICT_FMA,
1120+
ISD::VECREDUCE_FMIN,
1121+
ISD::VECREDUCE_FMAX,
1122+
ISD::VECREDUCE_FMINIMUM,
1123+
ISD::VECREDUCE_FMAXIMUM};
1124+
10941125
// TODO: support more vp ops.
10951126
static const unsigned ZvfhminZvfbfminPromoteVPOps[] = {
10961127
ISD::VP_FADD,
@@ -1295,11 +1326,11 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
12951326

12961327
// Custom split nxv32[b]f16 since nxv32[b]f32 is not legal.
12971328
if (getLMUL(VT) == RISCVVType::LMUL_8) {
1298-
setOperationAction(ZvfhminZvfbfminPromoteOps, VT, Custom);
1329+
setOperationAction(ZvfbfaPromoteOps, VT, Custom);
12991330
setOperationAction(ZvfhminZvfbfminPromoteVPOps, VT, Custom);
13001331
} else {
13011332
MVT F32VecVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
1302-
setOperationPromotedToType(ZvfhminZvfbfminPromoteOps, VT, F32VecVT);
1333+
setOperationPromotedToType(ZvfbfaPromoteOps, VT, F32VecVT);
13031334
setOperationPromotedToType(ZvfhminZvfbfminPromoteVPOps, VT, F32VecVT);
13041335
}
13051336
};
@@ -1616,7 +1647,11 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
16161647
// TODO: could split the f16 vector into two vectors and do promotion.
16171648
if (!isTypeLegal(F32VecVT))
16181649
continue;
1619-
setOperationPromotedToType(ZvfhminZvfbfminPromoteOps, VT, F32VecVT);
1650+
1651+
if (Subtarget.hasStdExtZvfbfa())
1652+
setOperationPromotedToType(ZvfbfaPromoteOps, VT, F32VecVT);
1653+
else
1654+
setOperationPromotedToType(ZvfhminZvfbfminPromoteOps, VT, F32VecVT);
16201655
setOperationPromotedToType(ZvfhminZvfbfminPromoteVPOps, VT, F32VecVT);
16211656
continue;
16221657
}

llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,8 @@ multiclass VPatBinaryFPSDNode_VV_VF<SDPatternOperator vop, string instruction_na
215215
}
216216

217217
multiclass VPatBinaryFPSDNode_VV_VF_RM<SDPatternOperator vop, string instruction_name,
218-
bit isSEWAware = 0> {
219-
foreach vti = AllFloatVectors in {
218+
bit isSEWAware = 0, bit isBF16 = 0> {
219+
foreach vti = !if(isBF16, AllBF16Vectors, AllFloatVectors) in {
220220
let Predicates = GetVTypePredicates<vti>.Predicates in {
221221
def : VPatBinarySDNode_VV_RM<vop, instruction_name,
222222
vti.Vector, vti.Vector, vti.Log2SEW,
@@ -246,8 +246,8 @@ multiclass VPatBinaryFPSDNode_R_VF<SDPatternOperator vop, string instruction_nam
246246
}
247247

248248
multiclass VPatBinaryFPSDNode_R_VF_RM<SDPatternOperator vop, string instruction_name,
249-
bit isSEWAware = 0> {
250-
foreach fvti = AllFloatVectors in
249+
bit isSEWAware = 0, bit isBF16 = 0> {
250+
foreach fvti = !if(isBF16, AllBF16Vectors, AllFloatVectors) in
251251
let Predicates = GetVTypePredicates<fvti>.Predicates in
252252
def : Pat<(fvti.Vector (vop (fvti.Vector (SplatFPOp fvti.Scalar:$rs2)),
253253
(fvti.Vector fvti.RegClass:$rs1))),

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,8 +1058,8 @@ multiclass VPatBinaryFPVL_VV_VF<SDPatternOperator vop, string instruction_name,
10581058
}
10591059

10601060
multiclass VPatBinaryFPVL_VV_VF_RM<SDPatternOperator vop, string instruction_name,
1061-
bit isSEWAware = 0> {
1062-
foreach vti = AllFloatVectors in {
1061+
bit isSEWAware = 0, bit isBF16 = 0> {
1062+
foreach vti = !if(isBF16, AllBF16Vectors, AllFloatVectors) in {
10631063
let Predicates = GetVTypePredicates<vti>.Predicates in {
10641064
def : VPatBinaryVL_V_RM<vop, instruction_name, "VV",
10651065
vti.Vector, vti.Vector, vti.Vector, vti.Mask,
@@ -1093,8 +1093,8 @@ multiclass VPatBinaryFPVL_R_VF<SDPatternOperator vop, string instruction_name,
10931093
}
10941094

10951095
multiclass VPatBinaryFPVL_R_VF_RM<SDPatternOperator vop, string instruction_name,
1096-
bit isSEWAware = 0> {
1097-
foreach fvti = AllFloatVectors in {
1096+
bit isSEWAware = 0, bit isBF16 = 0> {
1097+
foreach fvti = !if(isBF16, AllBF16Vectors, AllFloatVectors) in {
10981098
let Predicates = GetVTypePredicates<fvti>.Predicates in
10991099
def : Pat<(fvti.Vector (vop (SplatFPOp fvti.ScalarRegClass:$rs2),
11001100
fvti.RegClass:$rs1,

llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -783,4 +783,22 @@ let Predicates = [HasStdExtZvfbfa] in {
783783
TAIL_AGNOSTIC)>;
784784
}
785785
}
786+
787+
defm : VPatBinaryFPSDNode_VV_VF_RM<any_fadd, "PseudoVFADD_ALT",
788+
isSEWAware=1, isBF16=1>;
789+
defm : VPatBinaryFPSDNode_VV_VF_RM<any_fsub, "PseudoVFSUB_ALT",
790+
isSEWAware=1, isBF16=1>;
791+
defm : VPatBinaryFPSDNode_VV_VF_RM<any_fmul, "PseudoVFMUL_ALT",
792+
isSEWAware=1, isBF16=1>;
793+
defm : VPatBinaryFPSDNode_R_VF_RM<any_fsub, "PseudoVFRSUB_ALT",
794+
isSEWAware=1, isBF16=1>;
795+
796+
defm : VPatBinaryFPVL_VV_VF_RM<any_riscv_fadd_vl, "PseudoVFADD_ALT",
797+
isSEWAware=1, isBF16=1>;
798+
defm : VPatBinaryFPVL_VV_VF_RM<any_riscv_fsub_vl, "PseudoVFSUB_ALT",
799+
isSEWAware=1, isBF16=1>;
800+
defm : VPatBinaryFPVL_VV_VF_RM<any_riscv_fmul_vl, "PseudoVFMUL_ALT",
801+
isSEWAware=1, isBF16=1>;
802+
defm : VPatBinaryFPVL_R_VF_RM<any_riscv_fsub_vl, "PseudoVFRSUB_ALT",
803+
isSEWAware=1, isBF16=1>;
786804
} // Predicates = [HasStdExtZvfbfa]
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2+
; RUN: llc -mtriple=riscv32 -mattr=+experimental-zvfbfa,+v \
3+
; RUN: -verify-machineinstrs < %s | FileCheck %s
4+
; RUN: llc -mtriple=riscv64 -mattr=+experimental-zvfbfa,+v \
5+
; RUN: -verify-machineinstrs < %s | FileCheck %s
6+
7+
define <1 x bfloat> @vfadd_vv_v1bf16(<1 x bfloat> %va, <1 x bfloat> %vb) {
8+
; CHECK-LABEL: vfadd_vv_v1bf16:
9+
; CHECK: # %bb.0:
10+
; CHECK-NEXT: vsetivli zero, 1, e16alt, mf4, ta, ma
11+
; CHECK-NEXT: vfadd.vv v8, v8, v9
12+
; CHECK-NEXT: ret
13+
%vc = fadd <1 x bfloat> %va, %vb
14+
ret <1 x bfloat> %vc
15+
}
16+
17+
define <1 x bfloat> @vfadd_vf_v1bf16(<1 x bfloat> %va, bfloat %b) {
18+
; CHECK-LABEL: vfadd_vf_v1bf16:
19+
; CHECK: # %bb.0:
20+
; CHECK-NEXT: vsetivli zero, 1, e16alt, mf4, ta, ma
21+
; CHECK-NEXT: vfadd.vf v8, v8, fa0
22+
; CHECK-NEXT: ret
23+
%head = insertelement <1 x bfloat> poison, bfloat %b, i32 0
24+
%splat = shufflevector <1 x bfloat> %head, <1 x bfloat> poison, <1 x i32> zeroinitializer
25+
%vc = fadd <1 x bfloat> %va, %splat
26+
ret <1 x bfloat> %vc
27+
}
28+
29+
define <2 x bfloat> @vfadd_vv_v2bf16(<2 x bfloat> %va, <2 x bfloat> %vb) {
30+
; CHECK-LABEL: vfadd_vv_v2bf16:
31+
; CHECK: # %bb.0:
32+
; CHECK-NEXT: vsetivli zero, 2, e16alt, mf4, ta, ma
33+
; CHECK-NEXT: vfadd.vv v8, v8, v9
34+
; CHECK-NEXT: ret
35+
%vc = fadd <2 x bfloat> %va, %vb
36+
ret <2 x bfloat> %vc
37+
}
38+
39+
define <2 x bfloat> @vfadd_vf_v2bf16(<2 x bfloat> %va, bfloat %b) {
40+
; CHECK-LABEL: vfadd_vf_v2bf16:
41+
; CHECK: # %bb.0:
42+
; CHECK-NEXT: vsetivli zero, 2, e16alt, mf4, ta, ma
43+
; CHECK-NEXT: vfadd.vf v8, v8, fa0
44+
; CHECK-NEXT: ret
45+
%head = insertelement <2 x bfloat> poison, bfloat %b, i32 0
46+
%splat = shufflevector <2 x bfloat> %head, <2 x bfloat> poison, <2 x i32> zeroinitializer
47+
%vc = fadd <2 x bfloat> %va, %splat
48+
ret <2 x bfloat> %vc
49+
}
50+
51+
define <4 x bfloat> @vfadd_vv_v4bf16(<4 x bfloat> %va, <4 x bfloat> %vb) {
52+
; CHECK-LABEL: vfadd_vv_v4bf16:
53+
; CHECK: # %bb.0:
54+
; CHECK-NEXT: vsetivli zero, 4, e16alt, mf2, ta, ma
55+
; CHECK-NEXT: vfadd.vv v8, v8, v9
56+
; CHECK-NEXT: ret
57+
%vc = fadd <4 x bfloat> %va, %vb
58+
ret <4 x bfloat> %vc
59+
}
60+
61+
define <4 x bfloat> @vfadd_vf_v4bf16(<4 x bfloat> %va, bfloat %b) {
62+
; CHECK-LABEL: vfadd_vf_v4bf16:
63+
; CHECK: # %bb.0:
64+
; CHECK-NEXT: vsetivli zero, 4, e16alt, mf2, ta, ma
65+
; CHECK-NEXT: vfadd.vf v8, v8, fa0
66+
; CHECK-NEXT: ret
67+
%head = insertelement <4 x bfloat> poison, bfloat %b, i32 0
68+
%splat = shufflevector <4 x bfloat> %head, <4 x bfloat> poison, <4 x i32> zeroinitializer
69+
%vc = fadd <4 x bfloat> %va, %splat
70+
ret <4 x bfloat> %vc
71+
}
72+
73+
define <8 x bfloat> @vfadd_vv_v8bf16(<8 x bfloat> %va, <8 x bfloat> %vb) {
74+
; CHECK-LABEL: vfadd_vv_v8bf16:
75+
; CHECK: # %bb.0:
76+
; CHECK-NEXT: vsetivli zero, 8, e16alt, m1, ta, ma
77+
; CHECK-NEXT: vfadd.vv v8, v8, v9
78+
; CHECK-NEXT: ret
79+
%vc = fadd <8 x bfloat> %va, %vb
80+
ret <8 x bfloat> %vc
81+
}
82+
83+
define <8 x bfloat> @vfadd_vf_v8bf16(<8 x bfloat> %va, bfloat %b) {
84+
; CHECK-LABEL: vfadd_vf_v8bf16:
85+
; CHECK: # %bb.0:
86+
; CHECK-NEXT: vsetivli zero, 8, e16alt, m1, ta, ma
87+
; CHECK-NEXT: vfadd.vf v8, v8, fa0
88+
; CHECK-NEXT: ret
89+
%head = insertelement <8 x bfloat> poison, bfloat %b, i32 0
90+
%splat = shufflevector <8 x bfloat> %head, <8 x bfloat> poison, <8 x i32> zeroinitializer
91+
%vc = fadd <8 x bfloat> %va, %splat
92+
ret <8 x bfloat> %vc
93+
}
94+
95+
define <16 x bfloat> @vfadd_vv_v16bf16(<16 x bfloat> %va, <16 x bfloat> %vb) {
96+
; CHECK-LABEL: vfadd_vv_v16bf16:
97+
; CHECK: # %bb.0:
98+
; CHECK-NEXT: vsetivli zero, 16, e16alt, m2, ta, ma
99+
; CHECK-NEXT: vfadd.vv v8, v8, v10
100+
; CHECK-NEXT: ret
101+
%vc = fadd <16 x bfloat> %va, %vb
102+
ret <16 x bfloat> %vc
103+
}
104+
105+
define <16 x bfloat> @vfadd_vf_v16bf16(<16 x bfloat> %va, bfloat %b) {
106+
; CHECK-LABEL: vfadd_vf_v16bf16:
107+
; CHECK: # %bb.0:
108+
; CHECK-NEXT: vsetivli zero, 16, e16alt, m2, ta, ma
109+
; CHECK-NEXT: vfadd.vf v8, v8, fa0
110+
; CHECK-NEXT: ret
111+
%head = insertelement <16 x bfloat> poison, bfloat %b, i32 0
112+
%splat = shufflevector <16 x bfloat> %head, <16 x bfloat> poison, <16 x i32> zeroinitializer
113+
%vc = fadd <16 x bfloat> %va, %splat
114+
ret <16 x bfloat> %vc
115+
}
116+
117+
define <32 x bfloat> @vfadd_vv_v32bf16(<32 x bfloat> %va, <32 x bfloat> %vb) {
118+
; CHECK-LABEL: vfadd_vv_v32bf16:
119+
; CHECK: # %bb.0:
120+
; CHECK-NEXT: li a0, 32
121+
; CHECK-NEXT: vsetvli zero, a0, e16alt, m4, ta, ma
122+
; CHECK-NEXT: vfadd.vv v8, v8, v12
123+
; CHECK-NEXT: ret
124+
%vc = fadd <32 x bfloat> %va, %vb
125+
ret <32 x bfloat> %vc
126+
}
127+
128+
define <32 x bfloat> @vfadd_vf_v32bf16(<32 x bfloat> %va, bfloat %b) {
129+
; CHECK-LABEL: vfadd_vf_v32bf16:
130+
; CHECK: # %bb.0:
131+
; CHECK-NEXT: li a0, 32
132+
; CHECK-NEXT: vsetvli zero, a0, e16alt, m4, ta, ma
133+
; CHECK-NEXT: vfadd.vf v8, v8, fa0
134+
; CHECK-NEXT: ret
135+
%head = insertelement <32 x bfloat> poison, bfloat %b, i32 0
136+
%splat = shufflevector <32 x bfloat> %head, <32 x bfloat> poison, <32 x i32> zeroinitializer
137+
%vc = fadd <32 x bfloat> %va, %splat
138+
ret <32 x bfloat> %vc
139+
}
140+
141+
define <64 x bfloat> @vfadd_vv_v64bf16(<64 x bfloat> %va, <64 x bfloat> %vb) {
142+
; CHECK-LABEL: vfadd_vv_v64bf16:
143+
; CHECK: # %bb.0:
144+
; CHECK-NEXT: li a0, 64
145+
; CHECK-NEXT: vsetvli zero, a0, e16alt, m8, ta, ma
146+
; CHECK-NEXT: vfadd.vv v8, v8, v16
147+
; CHECK-NEXT: ret
148+
%vc = fadd <64 x bfloat> %va, %vb
149+
ret <64 x bfloat> %vc
150+
}
151+
152+
define <64 x bfloat> @vfadd_vf_v64bf16(<64 x bfloat> %va, bfloat %b) {
153+
; CHECK-LABEL: vfadd_vf_v64bf16:
154+
; CHECK: # %bb.0:
155+
; CHECK-NEXT: li a0, 64
156+
; CHECK-NEXT: vsetvli zero, a0, e16alt, m8, ta, ma
157+
; CHECK-NEXT: vfadd.vf v8, v8, fa0
158+
; CHECK-NEXT: ret
159+
%head = insertelement <64 x bfloat> poison, bfloat %b, i32 0
160+
%splat = shufflevector <64 x bfloat> %head, <64 x bfloat> poison, <64 x i32> zeroinitializer
161+
%vc = fadd <64 x bfloat> %va, %splat
162+
ret <64 x bfloat> %vc
163+
}

0 commit comments

Comments
 (0)