[SimplifyCFG] FoldTwoEntryPHINode(): consider *total* speculation cost, not per-BB...
[llvm-complete.git] / lib / Transforms / Vectorize / VPlanPredicator.cpp
blob7a80f3ff80a5c3bb48d03a2953fc320d804dc629
1 //===-- VPlanPredicator.cpp -------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements the VPlanPredicator class which contains the public
11 /// interfaces to predicate and linearize the VPlan region.
12 ///
13 //===----------------------------------------------------------------------===//
15 #include "VPlanPredicator.h"
16 #include "VPlan.h"
17 #include "llvm/ADT/DepthFirstIterator.h"
18 #include "llvm/ADT/GraphTraits.h"
19 #include "llvm/ADT/PostOrderIterator.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
23 #define DEBUG_TYPE "VPlanPredicator"
25 using namespace llvm;
27 // Generate VPInstructions at the beginning of CurrBB that calculate the
28 // predicate being propagated from PredBB to CurrBB depending on the edge type
29 // between them. For example if:
30 // i. PredBB is controlled by predicate %BP, and
31 // ii. The edge PredBB->CurrBB is the false edge, controlled by the condition
32 // bit value %CBV then this function will generate the following two
33 // VPInstructions at the start of CurrBB:
34 // %IntermediateVal = not %CBV
35 // %FinalVal = and %BP %IntermediateVal
36 // It returns %FinalVal.
37 VPValue *VPlanPredicator::getOrCreateNotPredicate(VPBasicBlock *PredBB,
38 VPBasicBlock *CurrBB) {
39 VPValue *CBV = PredBB->getCondBit();
41 // Set the intermediate value - this is either 'CBV', or 'not CBV'
42 // depending on the edge type.
43 EdgeType ET = getEdgeTypeBetween(PredBB, CurrBB);
44 VPValue *IntermediateVal = nullptr;
45 switch (ET) {
46 case EdgeType::TRUE_EDGE:
47 // CurrBB is the true successor of PredBB - nothing to do here.
48 IntermediateVal = CBV;
49 break;
51 case EdgeType::FALSE_EDGE:
52 // CurrBB is the False successor of PredBB - compute not of CBV.
53 IntermediateVal = Builder.createNot(CBV);
54 break;
57 // Now AND intermediate value with PredBB's block predicate if it has one.
58 VPValue *BP = PredBB->getPredicate();
59 if (BP)
60 return Builder.createAnd(BP, IntermediateVal);
61 else
62 return IntermediateVal;
65 // Generate a tree of ORs for all IncomingPredicates in WorkList.
66 // Note: This function destroys the original Worklist.
68 // P1 P2 P3 P4 P5
69 // \ / \ / /
70 // OR1 OR2 /
71 // \ | /
72 // \ +/-+
73 // \ / |
74 // OR3 |
75 // \ |
76 // OR4 <- Returns this
77 // |
79 // The algorithm uses a worklist of predicates as its main data structure.
80 // We pop a pair of values from the front (e.g. P1 and P2), generate an OR
81 // (in this example OR1), and push it back. In this example the worklist
82 // contains {P3, P4, P5, OR1}.
83 // The process iterates until we have only one element in the Worklist (OR4).
84 // The last element is the root predicate which is returned.
85 VPValue *VPlanPredicator::genPredicateTree(std::list<VPValue *> &Worklist) {
86 if (Worklist.empty())
87 return nullptr;
89 // The worklist initially contains all the leaf nodes. Initialize the tree
90 // using them.
91 while (Worklist.size() >= 2) {
92 // Pop a pair of values from the front.
93 VPValue *LHS = Worklist.front();
94 Worklist.pop_front();
95 VPValue *RHS = Worklist.front();
96 Worklist.pop_front();
98 // Create an OR of these values.
99 VPValue *Or = Builder.createOr(LHS, RHS);
101 // Push OR to the back of the worklist.
102 Worklist.push_back(Or);
105 assert(Worklist.size() == 1 && "Expected 1 item in worklist");
107 // The root is the last node in the worklist.
108 VPValue *Root = Worklist.front();
110 // This root needs to replace the existing block predicate. This is done in
111 // the caller function.
112 return Root;
115 // Return whether the edge FromBlock -> ToBlock is a TRUE_EDGE or FALSE_EDGE
116 VPlanPredicator::EdgeType
117 VPlanPredicator::getEdgeTypeBetween(VPBlockBase *FromBlock,
118 VPBlockBase *ToBlock) {
119 unsigned Count = 0;
120 for (VPBlockBase *SuccBlock : FromBlock->getSuccessors()) {
121 if (SuccBlock == ToBlock) {
122 assert(Count < 2 && "Switch not supported currently");
123 return (Count == 0) ? EdgeType::TRUE_EDGE : EdgeType::FALSE_EDGE;
125 Count++;
128 llvm_unreachable("Broken getEdgeTypeBetween");
131 // Generate all predicates needed for CurrBlock by going through its immediate
132 // predecessor blocks.
133 void VPlanPredicator::createOrPropagatePredicates(VPBlockBase *CurrBlock,
134 VPRegionBlock *Region) {
135 // Blocks that dominate region exit inherit the predicate from the region.
136 // Return after setting the predicate.
137 if (VPDomTree.dominates(CurrBlock, Region->getExit())) {
138 VPValue *RegionBP = Region->getPredicate();
139 CurrBlock->setPredicate(RegionBP);
140 return;
143 // Collect all incoming predicates in a worklist.
144 std::list<VPValue *> IncomingPredicates;
146 // Set the builder's insertion point to the top of the current BB
147 VPBasicBlock *CurrBB = cast<VPBasicBlock>(CurrBlock->getEntryBasicBlock());
148 Builder.setInsertPoint(CurrBB, CurrBB->begin());
150 // For each predecessor, generate the VPInstructions required for
151 // computing 'BP AND (not) CBV" at the top of CurrBB.
152 // Collect the outcome of this calculation for all predecessors
153 // into IncomingPredicates.
154 for (VPBlockBase *PredBlock : CurrBlock->getPredecessors()) {
155 // Skip back-edges
156 if (VPBlockUtils::isBackEdge(PredBlock, CurrBlock, VPLI))
157 continue;
159 VPValue *IncomingPredicate = nullptr;
160 unsigned NumPredSuccsNoBE =
161 VPBlockUtils::countSuccessorsNoBE(PredBlock, VPLI);
163 // If there is an unconditional branch to the currBB, then we don't create
164 // edge predicates. We use the predecessor's block predicate instead.
165 if (NumPredSuccsNoBE == 1)
166 IncomingPredicate = PredBlock->getPredicate();
167 else if (NumPredSuccsNoBE == 2) {
168 // Emit recipes into CurrBlock if required
169 assert(isa<VPBasicBlock>(PredBlock) && "Only BBs have multiple exits");
170 IncomingPredicate =
171 getOrCreateNotPredicate(cast<VPBasicBlock>(PredBlock), CurrBB);
172 } else
173 llvm_unreachable("FIXME: switch statement ?");
175 if (IncomingPredicate)
176 IncomingPredicates.push_back(IncomingPredicate);
179 // Logically OR all incoming predicates by building the Predicate Tree.
180 VPValue *Predicate = genPredicateTree(IncomingPredicates);
182 // Now update the block's predicate with the new one.
183 CurrBlock->setPredicate(Predicate);
186 // Generate all predicates needed for Region.
187 void VPlanPredicator::predicateRegionRec(VPRegionBlock *Region) {
188 VPBasicBlock *EntryBlock = cast<VPBasicBlock>(Region->getEntry());
189 ReversePostOrderTraversal<VPBlockBase *> RPOT(EntryBlock);
191 // Generate edge predicates and append them to the block predicate. RPO is
192 // necessary since the predecessor blocks' block predicate needs to be set
193 // before the current block's block predicate can be computed.
194 for (VPBlockBase *Block : make_range(RPOT.begin(), RPOT.end())) {
195 // TODO: Handle nested regions once we start generating the same.
196 assert(!isa<VPRegionBlock>(Block) && "Nested region not expected");
197 createOrPropagatePredicates(Block, Region);
201 // Linearize the CFG within Region.
202 // TODO: Predication and linearization need RPOT for every region.
203 // This traversal is expensive. Since predication is not adding new
204 // blocks, we should be able to compute RPOT once in predication and
205 // reuse it here. This becomes even more important once we have nested
206 // regions.
207 void VPlanPredicator::linearizeRegionRec(VPRegionBlock *Region) {
208 ReversePostOrderTraversal<VPBlockBase *> RPOT(Region->getEntry());
209 VPBlockBase *PrevBlock = nullptr;
211 for (VPBlockBase *CurrBlock : make_range(RPOT.begin(), RPOT.end())) {
212 // TODO: Handle nested regions once we start generating the same.
213 assert(!isa<VPRegionBlock>(CurrBlock) && "Nested region not expected");
215 // Linearize control flow by adding an unconditional edge between PrevBlock
216 // and CurrBlock skipping loop headers and latches to keep intact loop
217 // header predecessors and loop latch successors.
218 if (PrevBlock && !VPLI->isLoopHeader(CurrBlock) &&
219 !VPBlockUtils::blockIsLoopLatch(PrevBlock, VPLI)) {
221 LLVM_DEBUG(dbgs() << "Linearizing: " << PrevBlock->getName() << "->"
222 << CurrBlock->getName() << "\n");
224 PrevBlock->clearSuccessors();
225 CurrBlock->clearPredecessors();
226 VPBlockUtils::connectBlocks(PrevBlock, CurrBlock);
229 PrevBlock = CurrBlock;
233 // Entry point. The driver function for the predicator.
234 void VPlanPredicator::predicate(void) {
235 // Predicate the blocks within Region.
236 predicateRegionRec(cast<VPRegionBlock>(Plan.getEntry()));
238 // Linearlize the blocks with Region.
239 linearizeRegionRec(cast<VPRegionBlock>(Plan.getEntry()));
242 VPlanPredicator::VPlanPredicator(VPlan &Plan)
243 : Plan(Plan), VPLI(&(Plan.getVPLoopInfo())) {
244 // FIXME: Predicator is currently computing the dominator information for the
245 // top region. Once we start storing dominator information in a VPRegionBlock,
246 // we can avoid this recalculation.
247 VPDomTree.recalculate(*(cast<VPRegionBlock>(Plan.getEntry())));