diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index e08c7f6bb3c49..5f42295f0fe21 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -1027,6 +1027,7 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
                        ISD::SCALAR_TO_VECTOR,
                        ISD::ZERO_EXTEND,
                        ISD::SIGN_EXTEND_INREG,
+                       ISD::ANY_EXTEND,
                        ISD::EXTRACT_VECTOR_ELT,
                        ISD::INSERT_VECTOR_ELT,
                        ISD::FCOPYSIGN});
@@ -13289,6 +13290,20 @@ static uint32_t getPermuteMask(SDValue V) {
   return ~0;
 }
 
+static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI);
+
+SDValue SITargetLowering::performLeftShiftCombine(SDNode *N,
+                                                  DAGCombinerInfo &DCI) const {
+  if (DCI.getDAGCombineLevel() < AfterLegalizeTypes)
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  if (VT != MVT::i32)
+    return SDValue();
+
+  return matchPERM(N, DCI);
+}
+
 SDValue SITargetLowering::performAndCombine(SDNode *N,
                                             DAGCombinerInfo &DCI) const {
   if (DCI.isBeforeLegalize())
@@ -13532,6 +13547,7 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
     return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
   }
 
+  case ISD::ANY_EXTEND:
   case ISD::SIGN_EXTEND:
   case ISD::ZERO_EXTEND:
   case ISD::SIGN_EXTEND_INREG: {
@@ -13874,6 +13890,9 @@ static bool addresses16Bits(int Mask) {
   int Low8 = Mask & 0xff;
   int Hi8 = (Mask & 0xff00) >> 8;
 
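+  // In a perm mask, select 0x0c does not address a byte of either source;
+  // it produces a constant zero byte (matchPERM below emits 0x0c for
+  // ByteProviders that are constant zero).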
+  if (Hi8 == 0x0c || Low8 == 0x0c)
+    return false;
+
   assert(Low8 < 8 && Hi8 < 8);
   // Are the bytes contiguous in the order of increasing addresses.
   bool IsConsecutive = (Hi8 - Low8 == 1);
@@ -13968,8 +13987,10 @@ static SDValue getDWordFromOffset(SelectionDAG &DAG, SDLoc SL, SDValue Src,
 
 static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   SelectionDAG &DAG = DCI.DAG;
-  EVT VT = N->getValueType(0);
+  assert(!DAG.getDataLayout().isBigEndian());
+
+  [[maybe_unused]] EVT VT = N->getValueType(0);
 
-  SmallVector<ByteProvider<SDValue>, 8> PermNodes;
+  SmallVector<ByteProvider<SDValue>, 4> PermNodes;
 
   // VT is known to be MVT::i32, so we need to provide 4 bytes.
   assert(VT == MVT::i32);
@@ -13977,66 +13998,95 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
     // Find the ByteProvider that provides the ith byte of the result of OR
     std::optional<ByteProvider<SDValue>> P =
         calculateByteProvider(SDValue(N, 0), i, 0, /*StartingIndex = */ i);
-    // TODO support constantZero
-    if (!P || P->isConstantZero())
+    if (!P)
       return SDValue();
 
     PermNodes.push_back(*P);
   }
-  if (PermNodes.size() != 4)
-    return SDValue();
 
-  std::pair<unsigned, unsigned> FirstSrc(0, PermNodes[0].SrcOffset / 4);
-  std::optional<std::pair<unsigned, unsigned>> SecondSrc;
+  static auto isSameSrc = [](SDValue SrcA, unsigned DWordA, SDValue SrcB,
+                             unsigned DWordB) {
+    // Bytes taken from different DWORDs of the same node correspond with
+    // different sources.
+    return SrcA == SrcB && DWordA == DWordB;
+  };
+
+  SDValue Src0, Src1;
+  unsigned DWord0, DWord1;
   uint64_t PermMask = 0x00000000;
   for (size_t i = 0; i < PermNodes.size(); i++) {
-    auto PermOp = PermNodes[i];
-    // Since the mask is applied to Src1:Src2, Src1 bytes must be offset
-    // by sizeof(Src2) = 4
-    int SrcByteAdjust = 4;
+    ByteProvider<SDValue> PermOp = PermNodes[i];
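+    // Encode a constant-zero provider with the perm select 0x0c, which
+    // produces 0x00 for that destination byte.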
+    if (PermOp.isConstantZero()) {
+      PermMask |= 0x0c << (i * 8);
+      continue;
+    }
 
-    // If the Src uses a byte from a different DWORD, then it corresponds
-    // with a difference source
-    if (!PermOp.hasSameSrc(PermNodes[FirstSrc.first]) ||
-        ((PermOp.SrcOffset / 4) != FirstSrc.second)) {
-      if (SecondSrc)
-        if (!PermOp.hasSameSrc(PermNodes[SecondSrc->first]) ||
-            ((PermOp.SrcOffset / 4) != SecondSrc->second))
-          return SDValue();
+    const SDValue SrcI = PermOp.Src.value();
+    const unsigned DWordI = PermOp.SrcOffset / 4;
+    const unsigned ByteI = PermOp.SrcOffset % 4;
+    if (!Src0) {
+      Src0 = SrcI;
+      DWord0 = DWordI;
+    }
 
-      // Set the index of the second distinct Src node
-      SecondSrc = {i, PermNodes[i].SrcOffset / 4};
-      assert(!(PermNodes[SecondSrc->first].Src->getValueSizeInBits() % 8));
-      SrcByteAdjust = 0;
+    if (!isSameSrc(Src0, DWord0, SrcI, DWordI)) {
+      if (!Src1) {
+        Src1 = SrcI;
+        DWord1 = DWordI;
+      } else if (!isSameSrc(Src1, DWord1, SrcI, DWordI))
+        return SDValue();
     }
-    assert((PermOp.SrcOffset % 4) + SrcByteAdjust < 8);
-    assert(!DAG.getDataLayout().isBigEndian());
-    PermMask |= ((PermOp.SrcOffset % 4) + SrcByteAdjust) << (i * 8);
+
+    // Since the mask is applied to Src0:Src1, Src0 bytes must be offset by
+    // sizeof(Src1) = 4. The DWORD must match as well: a byte from a
+    // different DWORD of the Src0 node belongs to Src1.
+    const int SrcByteAdjust = isSameSrc(SrcI, DWordI, Src0, DWord0) ? 4 : 0;
+    assert(ByteI + SrcByteAdjust < 8);
+    PermMask |= (ByteI + SrcByteAdjust) << (i * 8);
   }
+
   SDLoc DL(N);
-  SDValue Op = *PermNodes[FirstSrc.first].Src;
-  Op = getDWordFromOffset(DAG, DL, Op, FirstSrc.second);
+  SDValue Op = Src0;
+  Op = getDWordFromOffset(DAG, DL, Op, DWord0);
   assert(Op.getValueSizeInBits() == 32);
 
   // Check that we are not just extracting the bytes in order from an op
-  if (!SecondSrc) {
-    int Low16 = PermMask & 0xffff;
-    int Hi16 = (PermMask & 0xffff0000) >> 16;
-
-    bool WellFormedLow = (Low16 == 0x0504) || (Low16 == 0x0100);
-    bool WellFormedHi = (Hi16 == 0x0706) || (Hi16 == 0x0302);
-
-    // The perm op would really just produce Op. So combine into Op
-    if (WellFormedLow && WellFormedHi)
-      return DAG.getBitcast(MVT::getIntegerVT(32), Op);
+  if (!Src1) {
+    // With a single source, the only masks worth rewriting are a run of
+    // constant-zero (0x0c) selects in the low bytes followed by the bytes
+    // of Op in increasing order, i.e. the bytes of (shl Op, LeftShift).
+    unsigned LeftShift = 0;
+    unsigned Expected = 0x0c;
+    int I = 0;
+    for (; I < 4; ++I) {
+      unsigned Sel = 0xFF & (PermMask >> (I * 8));
+      if (Expected == 0x0c && Sel == 0x0c) {
+        LeftShift += 8;
+        continue;
+      }
+      if (Expected == 0x0c)
+        Expected = 4;
+      if (Sel != Expected)
+        break;
+      ++Expected;
+    }
+    if (I == 4) {
+      // The perm op would really just produce Op. So combine into Op.
+      if (LeftShift == 0)
+        return DAG.getBitcast(MVT::getIntegerVT(32), Op);
+
+      // Do not replace an existing shl with an identical one.
+      if (N->getOpcode() == ISD::SHL) {
+        auto *ShiftOp = dyn_cast<ConstantSDNode>(N->getOperand(1));
+        if (ShiftOp && ShiftOp->getZExtValue() == LeftShift)
+          return SDValue();
+      }
+
+      return DAG.getNode(ISD::SHL, DL, MVT::i32, Op,
+                         DAG.getConstant(LeftShift, DL, MVT::i32));
+    }
   }
 
-  SDValue OtherOp = SecondSrc ? *PermNodes[SecondSrc->first].Src : Op;
-
-  if (SecondSrc) {
-    OtherOp = getDWordFromOffset(DAG, DL, OtherOp, SecondSrc->second);
+  SDValue OtherOp;
+  if (Src1) {
+    OtherOp = getDWordFromOffset(DAG, DL, Src1, DWord1);
     assert(OtherOp.getValueSizeInBits() == 32);
-  }
+  } else
+    OtherOp = Op;
 
   if (hasNon16BitAccesses(PermMask, Op, OtherOp)) {
@@ -14315,10 +14365,11 @@ SDValue SITargetLowering::performXorCombine(SDNode *N,
   return SDValue();
 }
 
-SDValue SITargetLowering::performZeroExtendCombine(SDNode *N,
-                                                   DAGCombinerInfo &DCI) const {
+SDValue
+SITargetLowering::performZeroOrAnyExtendCombine(SDNode *N,
+                                                DAGCombinerInfo &DCI) const {
   if (!Subtarget->has16BitInsts() ||
-      DCI.getDAGCombineLevel() < AfterLegalizeDAG)
+      DCI.getDAGCombineLevel() < AfterLegalizeTypes)
     return SDValue();
 
   EVT VT = N->getValueType(0);
@@ -14329,7 +14380,41 @@ SDValue SITargetLowering::performZeroExtendCombine(SDNode *N,
   if (Src.getValueType() != MVT::i16)
     return SDValue();
 
-  return SDValue();
+  // TODO: We bail out below if SrcOffset is not in the first dword (>= 4).
+  // It's possible we're missing out on some combine opportunities, but we'd
+  // need to weigh the cost of extracting the byte from the upper dwords.
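+
+  // Trace each of the two live bytes of the extended i16 result back to the
+  // node and byte offset that produce it, then materialize both with a
+  // single perm whose remaining byte lanes are constant zero.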
+  std::optional<ByteProvider<SDValue>> BP0 =
+      calculateByteProvider(SDValue(N, 0), 0, 0, /*StartingIndex = */ 0);
+  if (!BP0.has_value() || 4 <= BP0->SrcOffset)
+    return SDValue();
+  SDValue V0 = BP0->Src.value_or(SDValue());
+
+  std::optional<ByteProvider<SDValue>> BP1 =
+      calculateByteProvider(SDValue(N, 0), 1, 0, /*StartingIndex = */ 1);
+  if (!BP1.has_value() || 4 <= BP1->SrcOffset)
+    return SDValue();
+  SDValue V1 = BP1->Src.value_or(SDValue());
+
+  if (!V0 || !V1 || V0 == V1)
+    return SDValue();
+
+  SelectionDAG &DAG = DCI.DAG;
+  SDLoc DL(N);
+  // Every result byte defaults to the constant-zero select (0x0c); the two
+  // low bytes are then redirected to the provider bytes. V0 is the first
+  // perm operand, so its bytes are addressed by selects 4..7.
+  uint32_t PermMask = 0x0c0c0c0c;
+  if (V0) {
+    V0 = DAG.getBitcastedAnyExtOrTrunc(V0, DL, MVT::i32);
+    PermMask = (PermMask & ~0xFF) | (BP0->SrcOffset + 4);
+  }
+
+  if (V1) {
+    V1 = DAG.getBitcastedAnyExtOrTrunc(V1, DL, MVT::i32);
+    PermMask = (PermMask & ~(0xFF << 8)) | (BP1->SrcOffset << 8);
+  }
+
+  SDValue P = DAG.getNode(AMDGPUISD::PERM, DL, MVT::i32, V0, V1,
+                          DAG.getConstant(PermMask, DL, MVT::i32));
+  return P;
 }
 
 SDValue
@@ -16997,6 +17082,12 @@ SDValue SITargetLowering::PerformDAGCombine(SDNode *N,
     return performMinMaxCombine(N, DCI);
   case ISD::FMA:
     return performFMACombine(N, DCI);
+
+  case ISD::SHL:
+    if (auto Res = performLeftShiftCombine(N, DCI))
+      return Res;
+    break;
+
   case ISD::AND:
     return performAndCombine(N, DCI);
   case ISD::OR:
@@ -17011,8 +17102,9 @@ SDValue SITargetLowering::PerformDAGCombine(SDNode *N,
   }
   case ISD::XOR:
     return performXorCombine(N, DCI);
+  case ISD::ANY_EXTEND:
   case ISD::ZERO_EXTEND:
-    return performZeroExtendCombine(N, DCI);
+    return performZeroOrAnyExtendCombine(N, DCI);
   case ISD::SIGN_EXTEND_INREG:
     return performSignExtendInRegCombine(N, DCI);
   case AMDGPUISD::FP_CLASS:
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.h b/llvm/lib/Target/AMDGPU/SIISelLowering.h
index 74e58f4272e10..e61563cfaa964 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.h
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.h
@@ -202,10 +202,11 @@ class SITargetLowering final : public AMDGPUTargetLowering {
                               unsigned Opc, SDValue LHS,
                               const ConstantSDNode *CRHS) const;
 
+  SDValue performLeftShiftCombine(SDNode *N, DAGCombinerInfo &DCI) const;
   SDValue performAndCombine(SDNode *N, DAGCombinerInfo &DCI) const;
   SDValue performOrCombine(SDNode *N, DAGCombinerInfo &DCI) const;
   SDValue performXorCombine(SDNode *N, DAGCombinerInfo &DCI) const;
-  SDValue performZeroExtendCombine(SDNode *N, DAGCombinerInfo &DCI) const;
+  SDValue performZeroOrAnyExtendCombine(SDNode *N, DAGCombinerInfo &DCI) const;
   SDValue performSignExtendInRegCombine(SDNode *N, DAGCombinerInfo &DCI) const;
   SDValue performClassCombine(SDNode *N, DAGCombinerInfo &DCI) const;
   SDValue getCanonicalConstantFP(SelectionDAG &DAG, const SDLoc &SL, EVT VT,