@@ -438,7 +438,7 @@ namespace {
438438 SDValue visitSUBE(SDNode *N);
439439 SDValue visitUSUBO_CARRY(SDNode *N);
440440 SDValue visitSSUBO_CARRY(SDNode *N);
441- SDValue visitMUL(SDNode *N);
441+ template <class MatchContextClass> SDValue visitMUL(SDNode *N);
442442 SDValue visitMULFIX(SDNode *N);
443443 SDValue useDivRem(SDNode *N);
444444 SDValue visitSDIV(SDNode *N);
@@ -1855,7 +1855,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
18551855 case ISD::SMULFIXSAT:
18561856 case ISD::UMULFIX:
18571857 case ISD::UMULFIXSAT: return visitMULFIX(N);
1858- case ISD::MUL: return visitMUL(N);
1858+ case ISD::MUL: return visitMUL<EmptyMatchContext> (N);
18591859 case ISD::SDIV: return visitSDIV(N);
18601860 case ISD::UDIV: return visitUDIV(N);
18611861 case ISD::SREM:
@@ -4331,11 +4331,13 @@ SDValue DAGCombiner::visitMULFIX(SDNode *N) {
43314331 return SDValue();
43324332}
43334333
4334- SDValue DAGCombiner::visitMUL(SDNode *N) {
4334+ template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) {
43354335 SDValue N0 = N->getOperand(0);
43364336 SDValue N1 = N->getOperand(1);
43374337 EVT VT = N0.getValueType();
43384338 SDLoc DL(N);
4339+ bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
4340+ MatchContextClass Matcher(DAG, TLI, N);
43394341
43404342 // fold (mul x, undef) -> 0
43414343 if (N0.isUndef() || N1.isUndef())
@@ -4348,16 +4350,18 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
43484350 // canonicalize constant to RHS (vector doesn't have to splat)
43494351 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
43504352 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4351- return DAG .getNode(ISD::MUL, DL, VT, N1, N0);
4353+ return Matcher .getNode(ISD::MUL, DL, VT, N1, N0);
43524354
43534355 bool N1IsConst = false;
43544356 bool N1IsOpaqueConst = false;
43554357 APInt ConstValue1;
43564358
43574359 // fold vector ops
43584360 if (VT.isVector()) {
4359- if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4360- return FoldedVOp;
4361+ // TODO: Change this to use SimplifyVBinOp when it supports VP op.
4362+ if (!UseVP)
4363+ if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4364+ return FoldedVOp;
43614365
43624366 N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1);
43634367 assert((!N1IsConst ||
@@ -4379,20 +4383,21 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
43794383 if (N1IsConst && ConstValue1.isOne())
43804384 return N0;
43814385
4382- if (SDValue NewSel = foldBinOpIntoSelect(N))
4383- return NewSel;
4386+ if (!UseVP)
4387+ if (SDValue NewSel = foldBinOpIntoSelect(N))
4388+ return NewSel;
43844389
43854390 // fold (mul x, -1) -> 0-x
43864391 if (N1IsConst && ConstValue1.isAllOnes())
4387- return DAG.getNegative(N0 , DL, VT);
4392+ return Matcher.getNode(ISD::SUB, DL, VT, DAG.getConstant(0 , DL, VT), N0 );
43884393
43894394 // fold (mul x, (1 << c)) -> x << c
43904395 if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
43914396 (!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
43924397 if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
43934398 EVT ShiftVT = getShiftAmountTy(N0.getValueType());
43944399 SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
4395- return DAG .getNode(ISD::SHL, DL, VT, N0, Trunc);
4400+ return Matcher .getNode(ISD::SHL, DL, VT, N0, Trunc);
43964401 }
43974402 }
43984403
@@ -4403,24 +4408,26 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
44034408
44044409 // FIXME: If the input is something that is easily negated (e.g. a
44054410 // single-use add), we should put the negate there.
4406- return DAG .getNode(ISD::SUB, DL, VT,
4407- DAG.getConstant(0, DL, VT),
4408- DAG .getNode(ISD::SHL, DL, VT, N0,
4409- DAG.getConstant(Log2Val, DL, ShiftVT)));
4411+ return Matcher .getNode(
4412+ ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
4413+ Matcher .getNode(ISD::SHL, DL, VT, N0,
4414+ DAG.getConstant(Log2Val, DL, ShiftVT)));
44104415 }
44114416
44124417 // Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the
44134418 // hi result is in use in case we hit this mid-legalization.
4414- for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
4415- if (!LegalOperations || TLI.isOperationLegalOrCustom(LoHiOpc, VT)) {
4416- SDVTList LoHiVT = DAG.getVTList(VT, VT);
4417- // TODO: Can we match commutable operands with getNodeIfExists?
4418- if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N0, N1}))
4419- if (LoHi->hasAnyUseOfValue(1))
4420- return SDValue(LoHi, 0);
4421- if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N1, N0}))
4422- if (LoHi->hasAnyUseOfValue(1))
4423- return SDValue(LoHi, 0);
4419+ if (!UseVP) {
4420+ for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
4421+ if (!LegalOperations || TLI.isOperationLegalOrCustom(LoHiOpc, VT)) {
4422+ SDVTList LoHiVT = DAG.getVTList(VT, VT);
4423+ // TODO: Can we match commutable operands with getNodeIfExists?
4424+ if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N0, N1}))
4425+ if (LoHi->hasAnyUseOfValue(1))
4426+ return SDValue(LoHi, 0);
4427+ if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N1, N0}))
4428+ if (LoHi->hasAnyUseOfValue(1))
4429+ return SDValue(LoHi, 0);
4430+ }
44244431 }
44254432 }
44264433
@@ -4439,7 +4446,8 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
44394446 // x * 0xf800 --> (x << 16) - (x << 11)
44404447 // x * -0x8800 --> -((x << 15) + (x << 11))
44414448 // x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16)
4442- if (N1IsConst && TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
4449+ if (!UseVP && N1IsConst &&
4450+ TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
44434451 // TODO: We could handle more general decomposition of any constant by
44444452 // having the target set a limit on number of ops and making a
44454453 // callback to determine that sequence (similar to sqrt expansion).
@@ -4473,7 +4481,7 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
44734481 }
44744482
44754483 // (mul (shl X, c1), c2) -> (mul X, c2 << c1)
4476- if (N0.getOpcode() == ISD::SHL) {
4484+ if (sd_context_match(N0, Matcher, m_Opc( ISD::SHL)) ) {
44774485 SDValue N01 = N0.getOperand(1);
44784486 if (SDValue C3 = DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N1, N01}))
44794487 return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), C3);
@@ -4485,42 +4493,41 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
44854493 SDValue Sh, Y;
44864494
44874495 // Check for both (mul (shl X, C), Y) and (mul Y, (shl X, C)).
4488- if (N0.getOpcode() == ISD::SHL &&
4489- isConstantOrConstantVector(N0.getOperand(1)) && N0->hasOneUse() ) {
4496+ if (sd_context_match(N0, Matcher, m_OneUse(m_Opc( ISD::SHL))) &&
4497+ isConstantOrConstantVector(N0.getOperand(1))) {
44904498 Sh = N0; Y = N1;
4491- } else if (N1.getOpcode() == ISD::SHL &&
4492- isConstantOrConstantVector(N1.getOperand(1)) &&
4493- N1->hasOneUse()) {
4499+ } else if (sd_context_match(N1, Matcher, m_OneUse(m_Opc(ISD::SHL))) &&
4500+ isConstantOrConstantVector(N1.getOperand(1))) {
44944501 Sh = N1; Y = N0;
44954502 }
44964503
44974504 if (Sh.getNode()) {
4498- SDValue Mul = DAG .getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y);
4499- return DAG .getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1));
4505+ SDValue Mul = Matcher .getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y);
4506+ return Matcher .getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1));
45004507 }
45014508 }
45024509
45034510 // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
4504- if (N0.getOpcode() == ISD::ADD &&
4511+ if (sd_context_match(N0, Matcher, m_Opc( ISD::ADD)) &&
45054512 DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
45064513 DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) &&
45074514 isMulAddWithConstProfitable(N, N0, N1))
4508- return DAG .getNode(
4515+ return Matcher .getNode(
45094516 ISD::ADD, DL, VT,
4510- DAG .getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
4511- DAG .getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));
4517+ Matcher .getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
4518+ Matcher .getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));
45124519
45134520 // Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
45144521 ConstantSDNode *NC1 = isConstOrConstSplat(N1);
4515- if (N0.getOpcode() == ISD::VSCALE && NC1) {
4522+ if (!UseVP && N0.getOpcode() == ISD::VSCALE && NC1) {
45164523 const APInt &C0 = N0.getConstantOperandAPInt(0);
45174524 const APInt &C1 = NC1->getAPIntValue();
45184525 return DAG.getVScale(DL, VT, C0 * C1);
45194526 }
45204527
45214528 // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
45224529 APInt MulVal;
4523- if (N0.getOpcode() == ISD::STEP_VECTOR &&
4530+ if (!UseVP && N0.getOpcode() == ISD::STEP_VECTOR &&
45244531 ISD::isConstantSplatVector(N1.getNode(), MulVal)) {
45254532 const APInt &C0 = N0.getConstantOperandAPInt(0);
45264533 APInt NewStep = C0 * MulVal;
@@ -4558,13 +4565,17 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
45584565 }
45594566
45604567 // reassociate mul
4561- if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
4562- return RMUL;
4568+ // TODO: Change reassociateOps to support vp ops.
4569+ if (!UseVP)
4570+ if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
4571+ return RMUL;
45634572
45644573 // Fold mul(vecreduce(x), vecreduce(y)) -> vecreduce(mul(x, y))
4565- if (SDValue SD =
4566- reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1))
4567- return SD;
4574+ // TODO: Change reassociateReduction to support vp ops.
4575+ if (!UseVP)
4576+ if (SDValue SD =
4577+ reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1))
4578+ return SD;
45684579
45694580 // Simplify the operands using demanded-bits information.
45704581 if (SimplifyDemandedBits(SDValue(N, 0)))
@@ -26693,6 +26704,10 @@ SDValue DAGCombiner::visitVPOp(SDNode *N) {
2669326704 return visitFMA<VPMatchContext>(N);
2669426705 case ISD::VP_SELECT:
2669526706 return visitVP_SELECT(N);
26707+ case ISD::VP_MUL:
26708+ return visitMUL<VPMatchContext>(N);
26709+ default:
26710+ break;
2669626711 }
2669726712 return SDValue();
2669826713 }
@@ -27850,6 +27865,10 @@ static SDValue takeInexpensiveLog2(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
2785027865 if (!VT.isVector())
2785127866 return DAG.getConstant(Pow2Constants.back().logBase2(), DL, VT);
2785227867 // We need to create a build vector
27868+ if (Op.getOpcode() == ISD::SPLAT_VECTOR)
27869+ return DAG.getSplat(VT, DL,
27870+ DAG.getConstant(Pow2Constants.back().logBase2(), DL,
27871+ VT.getScalarType()));
2785327872 SmallVector<SDValue> Log2Ops;
2785427873 for (const APInt &Pow2 : Pow2Constants)
2785527874 Log2Ops.emplace_back(
0 commit comments