Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 135 additions & 11 deletions llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6208,6 +6208,102 @@ static bool hasMoreUses(const MachineInstr &MI0, const MachineInstr &MI1,
MRI.use_instr_nodbg_end());
}

/// Check if all uses of a multiply instruction can be contracted into FMAs.
/// This prevents creating FMAs when the multiply has other uses that can't
/// be contracted, which would duplicate the multiply computation.
static bool allMulUsesCanBeContracted(const MachineInstr &MulMI,
const MachineRegisterInfo &MRI,
bool AllowFusionGlobally) {
if (!isContractableFMul(const_cast<MachineInstr &>(MulMI),
AllowFusionGlobally))
return false;
Register MulReg = MulMI.getOperand(0).getReg();

// Check all uses of the multiply result
for (const MachineInstr &UseMI : MRI.use_nodbg_instructions(MulReg)) {
unsigned Opcode = UseMI.getOpcode();

// Direct FADD/FSUB uses
if (Opcode == TargetOpcode::G_FADD || Opcode == TargetOpcode::G_FSUB) {
// Check that we're not contracting both operands (which would duplicate)
Register Op1 = UseMI.getOperand(1).getReg();
Register Op2 = UseMI.getOperand(2).getReg();
if (Op1 == MulReg && Op2 == MulReg)
return false;

continue;
}

// FNEG → FSUB pattern
// Also handles FNEG → FPEXT → FSUB
if (Opcode == TargetOpcode::G_FNEG) {
Register FNegReg = UseMI.getOperand(0).getReg();
// ALL users of the FNEG must be contractable FSUBs or FPEXTs leading to
// FSUBs
for (const MachineInstr &FNegUseMI :
MRI.use_nodbg_instructions(FNegReg)) {
unsigned FNegUseOpcode = FNegUseMI.getOpcode();

if (FNegUseOpcode == TargetOpcode::G_FSUB) {
continue;
}
if (FNegUseOpcode == TargetOpcode::G_FPEXT) {
// FNEG → FPEXT → FSUB
Register FNegFPExtReg = FNegUseMI.getOperand(0).getReg();
for (const MachineInstr &FNegFPExtUseMI :
MRI.use_nodbg_instructions(FNegFPExtReg)) {
if (FNegFPExtUseMI.getOpcode() != TargetOpcode::G_FSUB)
return false;
}
continue;
}
return false;
}
continue;
}

// FP_EXTEND → {FADD, FSUB, FMA, FMAD} pattern
// Also handles FP_EXTEND → FNEG → FSUB
if (Opcode == TargetOpcode::G_FPEXT) {
Register FPExtReg = UseMI.getOperand(0).getReg();

// ALL users of the FP_EXTEND must be contractable operations or FNEGs
for (const MachineInstr &FPExtUseMI :
MRI.use_nodbg_instructions(FPExtReg)) {
unsigned ExtUseOpcode = FPExtUseMI.getOpcode();
if (ExtUseOpcode == TargetOpcode::G_FADD ||
ExtUseOpcode == TargetOpcode::G_FSUB ||
ExtUseOpcode == TargetOpcode::G_FMA ||
ExtUseOpcode == TargetOpcode::G_FMAD) {
continue;
}
if (ExtUseOpcode == TargetOpcode::G_FNEG) {
// FP_EXTEND → FNEG → FSUB
Register FPExtFNegReg = FPExtUseMI.getOperand(0).getReg();
for (const MachineInstr &FPExtFNegUseMI :
MRI.use_nodbg_instructions(FPExtFNegReg)) {
if (FPExtFNegUseMI.getOpcode() != TargetOpcode::G_FSUB)
return false;
}
continue;
}
return false;
}
continue;
}

// FMA/FMAD uses - assume contractable to avoid complex tracking
if (Opcode == TargetOpcode::G_FMA || Opcode == TargetOpcode::G_FMAD) {
continue;
}

// Any other use type is not contractable
return false;
}

return true;
}

bool CombinerHelper::canCombineFMadOrFMA(MachineInstr &MI,
bool &AllowFusionGlobally,
bool &HasFMAD, bool &Aggressive,
Expand Down Expand Up @@ -6265,7 +6361,9 @@ bool CombinerHelper::matchCombineFAddFMulToFMadOrFMA(

// fold (fadd (fmul x, y), z) -> (fma x, y, z)
if (isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
(Aggressive || MRI.hasOneNonDBGUse(LHS.Reg))) {
(MRI.hasOneNonDBGUse(LHS.Reg) ||
(Aggressive &&
allMulUsesCanBeContracted(*LHS.MI, MRI, AllowFusionGlobally)))) {
MatchInfo = [=, &MI](MachineIRBuilder &B) {
B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
{LHS.MI->getOperand(1).getReg(),
Expand All @@ -6276,7 +6374,9 @@ bool CombinerHelper::matchCombineFAddFMulToFMadOrFMA(

// fold (fadd x, (fmul y, z)) -> (fma y, z, x)
if (isContractableFMul(*RHS.MI, AllowFusionGlobally) &&
(Aggressive || MRI.hasOneNonDBGUse(RHS.Reg))) {
(MRI.hasOneNonDBGUse(RHS.Reg) ||
(Aggressive &&
allMulUsesCanBeContracted(*RHS.MI, MRI, AllowFusionGlobally)))) {
MatchInfo = [=, &MI](MachineIRBuilder &B) {
B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
{RHS.MI->getOperand(1).getReg(),
Expand Down Expand Up @@ -6319,6 +6419,9 @@ bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMA(
MachineInstr *FpExtSrc;
if (mi_match(LHS.Reg, MRI, m_GFPExt(m_MInstr(FpExtSrc))) &&
isContractableFMul(*FpExtSrc, AllowFusionGlobally) &&
(MRI.hasOneNonDBGUse(FpExtSrc->getOperand(0).getReg()) ||
(Aggressive &&
allMulUsesCanBeContracted(*FpExtSrc, MRI, AllowFusionGlobally))) &&
TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
MRI.getType(FpExtSrc->getOperand(1).getReg()))) {
MatchInfo = [=, &MI](MachineIRBuilder &B) {
Expand All @@ -6334,6 +6437,9 @@ bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMA(
// Note: Commutes FADD operands.
if (mi_match(RHS.Reg, MRI, m_GFPExt(m_MInstr(FpExtSrc))) &&
isContractableFMul(*FpExtSrc, AllowFusionGlobally) &&
(MRI.hasOneNonDBGUse(FpExtSrc->getOperand(0).getReg()) ||
(Aggressive &&
allMulUsesCanBeContracted(*FpExtSrc, MRI, AllowFusionGlobally))) &&
TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
MRI.getType(FpExtSrc->getOperand(1).getReg()))) {
MatchInfo = [=, &MI](MachineIRBuilder &B) {
Expand Down Expand Up @@ -6463,6 +6569,7 @@ bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMAAggressive(
mi_match(LHS.MI->getOperand(3).getReg(), MRI,
m_GFPExt(m_MInstr(FMulMI))) &&
isContractableFMul(*FMulMI, AllowFusionGlobally) &&
allMulUsesCanBeContracted(*FMulMI, MRI, AllowFusionGlobally) &&
TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
MRI.getType(FMulMI->getOperand(0).getReg()))) {
MatchInfo = [=](MachineIRBuilder &B) {
Expand All @@ -6483,6 +6590,7 @@ bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMAAggressive(
FMAMI->getOpcode() == PreferredFusedOpcode) {
MachineInstr *FMulMI = MRI.getVRegDef(FMAMI->getOperand(3).getReg());
if (isContractableFMul(*FMulMI, AllowFusionGlobally) &&
allMulUsesCanBeContracted(*FMulMI, MRI, AllowFusionGlobally) &&
TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
MRI.getType(FMAMI->getOperand(0).getReg()))) {
MatchInfo = [=](MachineIRBuilder &B) {
Expand All @@ -6504,6 +6612,7 @@ bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMAAggressive(
mi_match(RHS.MI->getOperand(3).getReg(), MRI,
m_GFPExt(m_MInstr(FMulMI))) &&
isContractableFMul(*FMulMI, AllowFusionGlobally) &&
allMulUsesCanBeContracted(*FMulMI, MRI, AllowFusionGlobally) &&
TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
MRI.getType(FMulMI->getOperand(0).getReg()))) {
MatchInfo = [=](MachineIRBuilder &B) {
Expand All @@ -6524,6 +6633,7 @@ bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMAAggressive(
FMAMI->getOpcode() == PreferredFusedOpcode) {
MachineInstr *FMulMI = MRI.getVRegDef(FMAMI->getOperand(3).getReg());
if (isContractableFMul(*FMulMI, AllowFusionGlobally) &&
allMulUsesCanBeContracted(*FMulMI, MRI, AllowFusionGlobally) &&
TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
MRI.getType(FMAMI->getOperand(0).getReg()))) {
MatchInfo = [=](MachineIRBuilder &B) {
Expand Down Expand Up @@ -6570,7 +6680,9 @@ bool CombinerHelper::matchCombineFSubFMulToFMadOrFMA(
// fold (fsub (fmul x, y), z) -> (fma x, y, -z)
if (FirstMulHasFewerUses &&
(isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
(Aggressive || MRI.hasOneNonDBGUse(LHS.Reg)))) {
(MRI.hasOneNonDBGUse(LHS.Reg) ||
(Aggressive &&
allMulUsesCanBeContracted(*LHS.MI, MRI, AllowFusionGlobally))))) {
MatchInfo = [=, &MI](MachineIRBuilder &B) {
Register NegZ = B.buildFNeg(DstTy, RHS.Reg).getReg(0);
B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
Expand All @@ -6580,8 +6692,10 @@ bool CombinerHelper::matchCombineFSubFMulToFMadOrFMA(
return true;
}
// fold (fsub x, (fmul y, z)) -> (fma -y, z, x)
else if ((isContractableFMul(*RHS.MI, AllowFusionGlobally) &&
(Aggressive || MRI.hasOneNonDBGUse(RHS.Reg)))) {
if (isContractableFMul(*RHS.MI, AllowFusionGlobally) &&
(MRI.hasOneNonDBGUse(RHS.Reg) ||
(Aggressive &&
allMulUsesCanBeContracted(*RHS.MI, MRI, AllowFusionGlobally)))) {
MatchInfo = [=, &MI](MachineIRBuilder &B) {
Register NegY =
B.buildFNeg(DstTy, RHS.MI->getOperand(1).getReg()).getReg(0);
Expand Down Expand Up @@ -6613,8 +6727,10 @@ bool CombinerHelper::matchCombineFSubFNegFMulToFMadOrFMA(
MachineInstr *FMulMI;
// fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
if (mi_match(LHSReg, MRI, m_GFNeg(m_MInstr(FMulMI))) &&
(Aggressive || (MRI.hasOneNonDBGUse(LHSReg) &&
MRI.hasOneNonDBGUse(FMulMI->getOperand(0).getReg()))) &&
((MRI.hasOneNonDBGUse(LHSReg) &&
MRI.hasOneNonDBGUse(FMulMI->getOperand(0).getReg())) ||
(Aggressive &&
allMulUsesCanBeContracted(*FMulMI, MRI, AllowFusionGlobally))) &&
isContractableFMul(*FMulMI, AllowFusionGlobally)) {
MatchInfo = [=, &MI](MachineIRBuilder &B) {
Register NegX =
Expand All @@ -6628,8 +6744,10 @@ bool CombinerHelper::matchCombineFSubFNegFMulToFMadOrFMA(

// fold (fsub x, (fneg (fmul, y, z))) -> (fma y, z, x)
if (mi_match(RHSReg, MRI, m_GFNeg(m_MInstr(FMulMI))) &&
(Aggressive || (MRI.hasOneNonDBGUse(RHSReg) &&
MRI.hasOneNonDBGUse(FMulMI->getOperand(0).getReg()))) &&
((MRI.hasOneNonDBGUse(RHSReg) &&
MRI.hasOneNonDBGUse(FMulMI->getOperand(0).getReg())) ||
(Aggressive &&
allMulUsesCanBeContracted(*FMulMI, MRI, AllowFusionGlobally))) &&
isContractableFMul(*FMulMI, AllowFusionGlobally)) {
MatchInfo = [=, &MI](MachineIRBuilder &B) {
B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
Expand Down Expand Up @@ -6662,7 +6780,9 @@ bool CombinerHelper::matchCombineFSubFpExtFMulToFMadOrFMA(
// fold (fsub (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), (fneg z))
if (mi_match(LHSReg, MRI, m_GFPExt(m_MInstr(FMulMI))) &&
isContractableFMul(*FMulMI, AllowFusionGlobally) &&
(Aggressive || MRI.hasOneNonDBGUse(LHSReg))) {
(MRI.hasOneNonDBGUse(LHSReg) ||
(Aggressive &&
allMulUsesCanBeContracted(*FMulMI, MRI, AllowFusionGlobally)))) {
MatchInfo = [=, &MI](MachineIRBuilder &B) {
Register FpExtX =
B.buildFPExt(DstTy, FMulMI->getOperand(1).getReg()).getReg(0);
Expand All @@ -6678,7 +6798,9 @@ bool CombinerHelper::matchCombineFSubFpExtFMulToFMadOrFMA(
// fold (fsub x, (fpext (fmul y, z))) -> (fma (fneg (fpext y)), (fpext z), x)
if (mi_match(RHSReg, MRI, m_GFPExt(m_MInstr(FMulMI))) &&
isContractableFMul(*FMulMI, AllowFusionGlobally) &&
(Aggressive || MRI.hasOneNonDBGUse(RHSReg))) {
(MRI.hasOneNonDBGUse(RHSReg) ||
(Aggressive &&
allMulUsesCanBeContracted(*FMulMI, MRI, AllowFusionGlobally)))) {
MatchInfo = [=, &MI](MachineIRBuilder &B) {
Register FpExtY =
B.buildFPExt(DstTy, FMulMI->getOperand(1).getReg()).getReg(0);
Expand Down Expand Up @@ -6726,6 +6848,7 @@ bool CombinerHelper::matchCombineFSubFpExtFNegFMulToFMadOrFMA(
if ((mi_match(LHSReg, MRI, m_GFPExt(m_GFNeg(m_MInstr(FMulMI)))) ||
mi_match(LHSReg, MRI, m_GFNeg(m_GFPExt(m_MInstr(FMulMI))))) &&
isContractableFMul(*FMulMI, AllowFusionGlobally) &&
allMulUsesCanBeContracted(*FMulMI, MRI, AllowFusionGlobally) &&
TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstTy,
MRI.getType(FMulMI->getOperand(0).getReg()))) {
MatchInfo = [=, &MI](MachineIRBuilder &B) {
Expand All @@ -6742,6 +6865,7 @@ bool CombinerHelper::matchCombineFSubFpExtFNegFMulToFMadOrFMA(
if ((mi_match(RHSReg, MRI, m_GFPExt(m_GFNeg(m_MInstr(FMulMI)))) ||
mi_match(RHSReg, MRI, m_GFNeg(m_GFPExt(m_MInstr(FMulMI))))) &&
isContractableFMul(*FMulMI, AllowFusionGlobally) &&
allMulUsesCanBeContracted(*FMulMI, MRI, AllowFusionGlobally) &&
TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstTy,
MRI.getType(FMulMI->getOperand(0).getReg()))) {
MatchInfo = [=, &MI](MachineIRBuilder &B) {
Expand Down
Loading