|
25 | 25 | #define DEBUG_TYPE "vplan" |
26 | 26 |
|
27 | 27 | using namespace llvm; |
| 28 | +using namespace VPlanPatternMatch; |
28 | 29 |
|
29 | 30 | namespace { |
30 | 31 | // Class that is used to build the plain CFG for the incoming IR. |
@@ -427,7 +428,6 @@ static void createLoopRegion(VPlan &Plan, VPBlockBase *HeaderVPB) { |
427 | 428 | static void addCanonicalIVRecipes(VPlan &Plan, VPBasicBlock *HeaderVPBB, |
428 | 429 | VPBasicBlock *LatchVPBB, Type *IdxTy, |
429 | 430 | DebugLoc DL) { |
430 | | - using namespace VPlanPatternMatch; |
431 | 431 | Value *StartIdx = ConstantInt::get(IdxTy, 0); |
432 | 432 | auto *StartV = Plan.getOrAddLiveIn(StartIdx); |
433 | 433 |
|
@@ -628,3 +628,114 @@ void VPlanTransforms::attachCheckBlock(VPlan &Plan, Value *Cond, |
628 | 628 | Term->addMetadata(LLVMContext::MD_prof, BranchWeights); |
629 | 629 | } |
630 | 630 | } |
| 631 | + |
| 632 | +bool VPlanTransforms::handleFMaxReductionsWithoutFastMath(VPlan &Plan) { |
| 633 | + VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion(); |
| 634 | + VPReductionPHIRecipe *RedPhiR = nullptr; |
| 635 | + VPRecipeWithIRFlags *MinMaxOp = nullptr; |
| 636 | + VPWidenIntOrFpInductionRecipe *WideIV = nullptr; |
| 637 | + |
| 638 | + // Check if there are any FMaxNoFMFs reductions using wide selects that we can |
| 639 | + // fix up. To do so, we also need a wide canonical IV to keep track of the |
| 640 | + // indices of the max values. |
| 641 | + for (auto &R : LoopRegion->getEntryBasicBlock()->phis()) { |
| 642 | + // We need a wide canonical IV |
| 643 | + if (auto *CurIV = dyn_cast<VPWidenIntOrFpInductionRecipe>(&R)) { |
| 644 | + if (!CurIV->isCanonical()) |
| 645 | + continue; |
| 646 | + WideIV = CurIV; |
| 647 | + continue; |
| 648 | + } |
| 649 | + |
| 650 | + // And a single FMaxNoFMFs reduction phi. |
| 651 | + // TODO: Support FMin reductions as well. |
| 652 | + auto *CurRedPhiR = dyn_cast<VPReductionPHIRecipe>(&R); |
| 653 | + if (!CurRedPhiR) |
| 654 | + continue; |
| 655 | + if (RedPhiR) |
| 656 | + return false; |
| 657 | + if (CurRedPhiR->getRecurrenceKind() != RecurKind::FMaxNoFMFs || |
| 658 | + CurRedPhiR->isInLoop() || CurRedPhiR->isOrdered()) |
| 659 | + continue; |
| 660 | + RedPhiR = CurRedPhiR; |
| 661 | + |
| 662 | + // MaxOp feeding the reduction phi must be a select (either wide or a |
| 663 | + // replicate recipe), where the phi is the last operand, and the compare |
| 664 | + // predicate is strict. This ensures NaNs won't get propagated unless the |
| 665 | + // initial value is NaN |
| 666 | + VPRecipeBase *Inc = RedPhiR->getBackedgeValue()->getDefiningRecipe(); |
| 667 | + auto *RepR = dyn_cast<VPReplicateRecipe>(Inc); |
| 668 | + if (!isa<VPWidenSelectRecipe>(Inc) && |
| 669 | + !(RepR && (isa<SelectInst>(RepR->getUnderlyingInstr())))) |
| 670 | + return false; |
| 671 | + |
| 672 | + MinMaxOp = cast<VPRecipeWithIRFlags>(Inc); |
| 673 | + auto *Cmp = cast<VPRecipeWithIRFlags>(MinMaxOp->getOperand(0)); |
| 674 | + if (MinMaxOp->getOperand(1) == RedPhiR || |
| 675 | + !CmpInst::isStrictPredicate(Cmp->getPredicate())) |
| 676 | + return false; |
| 677 | + } |
| 678 | + |
| 679 | + // Nothing to do. |
| 680 | + if (!RedPhiR) |
| 681 | + return true; |
| 682 | + |
| 683 | + // A wide canonical IV is currently required. |
| 684 | + // TODO: Create an induction if no suitable existing one is available. |
| 685 | + if (!WideIV) |
| 686 | + return false; |
| 687 | + |
| 688 | + // Create a reduction that tracks the first indices where the latest maximum |
| 689 | + // value has been selected. This is later used to select the max value from |
| 690 | + // the partial reductions in a way that correctly handles signed zeros and |
| 691 | + // NaNs in the input. |
| 692 | + // Note that we do not need to check if the induction may hit the sentinel |
| 693 | + // value. If the sentinel value gets hit, the final reduction value is at the |
| 694 | + // last index or the maximum was never set and all lanes contain the start |
| 695 | + // value. In either case, the correct value is selected. |
| 696 | + unsigned IVWidth = |
| 697 | + VPTypeAnalysis(Plan).inferScalarType(WideIV)->getScalarSizeInBits(); |
| 698 | + LLVMContext &Ctx = Plan.getScalarHeader()->getIRBasicBlock()->getContext(); |
| 699 | + VPValue *UMinSentinel = |
| 700 | + Plan.getOrAddLiveIn(ConstantInt::get(Ctx, APInt::getMaxValue(IVWidth))); |
| 701 | + auto *IdxPhi = new VPReductionPHIRecipe(nullptr, RecurKind::FindFirstIVUMin, |
| 702 | + *UMinSentinel, false, false, 1); |
| 703 | + IdxPhi->insertBefore(RedPhiR); |
| 704 | + auto *MinIdxSel = new VPInstruction( |
| 705 | + Instruction::Select, {MinMaxOp->getOperand(0), WideIV, IdxPhi}); |
| 706 | + MinIdxSel->insertAfter(MinMaxOp); |
| 707 | + IdxPhi->addOperand(MinIdxSel); |
| 708 | + |
| 709 | + // Find the first index of with the maximum value. This is used to extract the |
| 710 | + // lane with the final max value and is needed to handle signed zeros and NaNs |
| 711 | + // in the input. |
| 712 | + auto *MiddleVPBB = Plan.getMiddleBlock(); |
| 713 | + auto *OrigRdxResult = cast<VPSingleDefRecipe>(&MiddleVPBB->front()); |
| 714 | + VPBuilder Builder(OrigRdxResult->getParent(), |
| 715 | + std::next(OrigRdxResult->getIterator())); |
| 716 | + |
| 717 | + // Create mask for lanes that have the max value and use it to mask out |
| 718 | + // indices that don't contain maximum values. |
| 719 | + auto *MaskFinalMaxValue = Builder.createNaryOp( |
| 720 | + Instruction::FCmp, {OrigRdxResult->getOperand(1), OrigRdxResult}, |
| 721 | + VPIRFlags(CmpInst::FCMP_OEQ)); |
| 722 | + auto *IndicesWithMaxValue = Builder.createNaryOp( |
| 723 | + Instruction::Select, {MaskFinalMaxValue, MinIdxSel, UMinSentinel}); |
| 724 | + auto *FirstMaxIdx = Builder.createNaryOp( |
| 725 | + VPInstruction::ComputeFindIVResult, |
| 726 | + {IdxPhi, WideIV->getStartValue(), UMinSentinel, IndicesWithMaxValue}); |
| 727 | + // Convert the index of the first max value to an index in the vector lanes of |
| 728 | + // the partial reduction results. This ensures we select the first max value |
| 729 | + // and acts as a tie-breaker if the partial reductions contain signed zeros. |
| 730 | + auto *FirstMaxLane = |
| 731 | + Builder.createNaryOp(Instruction::URem, {FirstMaxIdx, &Plan.getVFxUF()}); |
| 732 | + |
| 733 | + // Extract the final max value and update the users. |
| 734 | + auto *Res = Builder.createNaryOp( |
| 735 | + VPInstruction::ExtractLane, {FirstMaxLane, OrigRdxResult->getOperand(1)}); |
| 736 | + OrigRdxResult->replaceUsesWithIf(Res, |
| 737 | + [MaskFinalMaxValue](VPUser &U, unsigned) { |
| 738 | + return &U != MaskFinalMaxValue; |
| 739 | + }); |
| 740 | + return true; |
| 741 | +} |
0 commit comments