//===- Inliner.cpp ---- SCC-based inliner ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Inliner, which uses a basic inlining algorithm that
// operates bottom up over the Strongly Connected Components (SCCs) of the
// CallGraph. This enables a more incremental propagation of inlining decisions
// from the leaves to the roots of the callgraph.
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Inliner.h"
#include "mlir/IR/Threading.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "inlining"
using namespace mlir;

using ResolvedCall = Inliner::ResolvedCall;
//===----------------------------------------------------------------------===//
// Symbol Use Tracking
//===----------------------------------------------------------------------===//
/// Walk all of the used symbol callgraph nodes referenced with the given op.
static void walkReferencedSymbolNodes(
    Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
    DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
    function_ref<void(CallGraphNode *, Operation *)> callback) {
  auto symbolUses = SymbolTable::getSymbolUses(op);
  assert(symbolUses && "expected uses to be valid");

  Operation *symbolTableOp = op->getParentOp();
  for (const SymbolTable::SymbolUse &use : *symbolUses) {
    auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
    CallGraphNode *&node = refIt.first->second;

    // If this is the first instance of this reference, try to resolve a
    // callgraph node for it.
    if (refIt.second) {
      auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
                                                           use.getSymbolRef());
      auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
      if (!callableOp)
        continue;
      node = cg.lookupNode(callableOp.getCallableRegion());
    }
    if (node)
      callback(node, use.getUser());
  }
}
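
// Illustrative example (editorial, not from the original file): given
//   func.func private @foo() { ... }
//   func.func @bar() { func.call @foo() : () -> () ... }
// walking the uses in @bar resolves the reference @foo to the callgraph node
// for @foo's body, and `callback` fires with that node and the using call op.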

//===----------------------------------------------------------------------===//
// CGUseList
//===----------------------------------------------------------------------===//

/// This struct tracks the uses of callgraph nodes that can be dropped when
/// use_empty. It directly tracks and manages a use-list for all of the
/// call-graph nodes. This is necessary because many callgraph nodes are
/// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
/// class.
struct CGUseList {
  /// This struct tracks the uses of callgraph nodes within a specific
  /// operation.
  struct CGUser {
    /// Any nodes referenced in the top-level attribute list of this user. We
    /// use a set here because the number of references does not matter.
    DenseSet<CallGraphNode *> topLevelUses;

    /// Uses of nodes referenced by nested operations.
    DenseMap<CallGraphNode *, int> innerUses;
  };

  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);

  /// Drop uses of nodes referred to by the given call operation that resides
  /// within 'userNode'.
  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);

  /// Remove the given node from the use list.
  void eraseNode(CallGraphNode *node);

  /// Returns true if the given callgraph node has no uses and can be pruned.
  bool isDead(CallGraphNode *node) const;

  /// Returns true if the given callgraph node has a single use and can be
  /// discarded.
  bool hasOneUseAndDiscardable(CallGraphNode *node) const;

  /// Recompute the uses held by the given callgraph node.
  void recomputeUses(CallGraphNode *node, CallGraph &cg);

  /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
  /// of 'lhs' into 'rhs'.
  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);

private:
  /// Decrement the uses of discardable nodes referenced by the given user.
  void decrementDiscardableUses(CGUser &uses);

  /// A mapping between a discardable callgraph node (that is a symbol) and the
  /// number of uses for this node.
  DenseMap<CallGraphNode *, int> discardableSymNodeUses;

  /// A mapping between a callgraph node and the symbol callgraph nodes that it
  /// uses.
  DenseMap<CallGraphNode *, CGUser> nodeUses;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
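
// Editorial note: the intent of the counters above is that a private symbol
// which canDiscardOnUseEmpty() gets an entry in `discardableSymNodeUses`; each
// SymbolRefAttr mentioning it contributes to the count, and once inlining has
// erased the last referencing call the count reaches zero and the callee can
// be deleted (see isDead()/hasOneUseAndDiscardable() below).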

CGUseList::CGUseList(Operation *op, CallGraph &cg,
                     SymbolTableCollection &symbolTable)
    : symbolTable(symbolTable) {
  // A set of callgraph nodes that are always known to be live during inlining.
  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;

  // Walk each of the symbol tables looking for discardable callgraph nodes.
  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
    for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
      // If this is a callgraph operation, check to see if it is discardable.
      if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
        if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
          SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
          if (symbol && (allUsesVisible || symbol.isPrivate()) &&
              symbol.canDiscardOnUseEmpty()) {
            discardableSymNodeUses.try_emplace(node, 0);
          }
          continue;
        }
      }
      // Otherwise, check for any referenced nodes. These will be always-live.
      walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
                                [](CallGraphNode *, Operation *) {});
    }
  };
  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
                                walkFn);

  // Drop the use information for any discardable nodes that are always live.
  for (auto &it : alwaysLiveNodes)
    discardableSymNodeUses.erase(it.second);

  // Compute the uses for each of the callable nodes in the graph.
  for (CallGraphNode *node : cg)
    recomputeUses(node, cg);
}
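
// Editorial note: `allSymUsesVisible` is passed as `!op->getBlock()`, i.e. it
// is true only when `op` is a top-level operation not nested in any block, in
// which case even public symbols cannot be referenced from outside the walked
// region and may safely be treated as discardable.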

void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
                             CallGraph &cg) {
  auto &userRefs = nodeUses[userNode].innerUses;
  auto walkFn = [&](CallGraphNode *node, Operation *user) {
    auto parentIt = userRefs.find(node);
    if (parentIt == userRefs.end())
      return;
    --parentIt->second;
    --discardableSymNodeUses[node];
  };
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::eraseNode(CallGraphNode *node) {
  // Drop all child nodes.
  for (auto &edge : *node)
    if (edge.isChild())
      eraseNode(edge.getTarget());

  // Drop the uses held by this node and erase it.
  auto useIt = nodeUses.find(node);
  assert(useIt != nodeUses.end() && "expected node to be valid");
  decrementDiscardableUses(useIt->getSecond());
  nodeUses.erase(useIt);
  discardableSymNodeUses.erase(node);
}

bool CGUseList::isDead(CallGraphNode *node) const {
  // If the parent operation isn't a symbol, simply check normal SSA deadness.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return isMemoryEffectFree(nodeOp) && nodeOp->use_empty();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
}

bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
  // If this isn't a symbol node, check for side-effects and SSA use count.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return isMemoryEffectFree(nodeOp) && nodeOp->hasOneUse();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
}
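
// Editorial note: these two predicates drive pruning during inlining below;
// isDead() lets the driver skip and delete callees that are already unused,
// while hasOneUseAndDiscardable() allows the final call site to take the
// callee body without cloning it, after which the callee node is erased.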

void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
  Operation *parentOp = node->getCallableRegion()->getParentOp();
  CGUser &uses = nodeUses[node];
  decrementDiscardableUses(uses);

  // Collect the new discardable uses within this node.
  uses = CGUser();
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
    auto discardSymIt = discardableSymNodeUses.find(refNode);
    if (discardSymIt == discardableSymNodeUses.end())
      return;

    if (user != parentOp)
      ++uses.innerUses[refNode];
    else if (!uses.topLevelUses.insert(refNode).second)
      return;
    ++discardSymIt->second;
  };
  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
}
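
// Editorial note: references held directly on the callable operation itself
// (e.g. in its attribute list) are deduplicated through `topLevelUses`, so
// they increment the discardable count at most once, whereas references from
// nested operations are counted one-by-one in `innerUses` so dropCallUses()
// can retire them call-by-call.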

void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
  auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
  for (auto &useIt : lhsUses.innerUses) {
    rhsUses.innerUses[useIt.first] += useIt.second;
    discardableSymNodeUses[useIt.first] += useIt.second;
  }
}

void CGUseList::decrementDiscardableUses(CGUser &uses) {
  for (CallGraphNode *node : uses.topLevelUses)
    --discardableSymNodeUses[node];
  for (auto &it : uses.innerUses)
    discardableSymNodeUses[it.first] -= it.second;
}

//===----------------------------------------------------------------------===//
// CallGraph traversal
//===----------------------------------------------------------------------===//

namespace {
/// This class represents a specific callgraph SCC.
class CallGraphSCC {
public:
  CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
      : parentIterator(parentIterator) {}
  /// Return a range over the nodes within this SCC.
  std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
  std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }

  /// Reset the nodes of this SCC with those provided.
  void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }

  /// Remove the given node from this SCC.
  void remove(CallGraphNode *node) {
    auto it = llvm::find(nodes, node);
    if (it != nodes.end()) {
      nodes.erase(it);
      parentIterator.ReplaceNode(node, nullptr);
    }
  }

private:
  std::vector<CallGraphNode *> nodes;
  llvm::scc_iterator<const CallGraph *> &parentIterator;
};
} // namespace

/// Run a given transformation over the SCCs of the callgraph in a bottom up
/// traversal.
static LogicalResult runTransformOnCGSCCs(
    const CallGraph &cg,
    function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
  llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
  CallGraphSCC currentSCC(cgi);
  while (!cgi.isAtEnd()) {
    // Copy the current SCC and increment so that the transformer can modify
    // the SCC without invalidating our iterator.
    currentSCC.reset(*cgi);
    ++cgi;
    if (failed(sccTransformer(currentSCC)))
      return failure();
  }
  return success();
}
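
// Editorial note: llvm::scc_begin enumerates SCCs in reverse topological
// order of the SCC DAG, so each SCC is visited before every SCC that calls
// into it; callees are therefore transformed before their callers, which is
// what makes this traversal "bottom up".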

/// Collect all of the callable operations within the given range of blocks. If
/// `traverseNestedCGNodes` is true, this will also collect call operations
/// inside of nested callgraph nodes.
static void collectCallOps(iterator_range<Region::iterator> blocks,
                           CallGraphNode *sourceNode, CallGraph &cg,
                           SymbolTableCollection &symbolTable,
                           SmallVectorImpl<ResolvedCall> &calls,
                           bool traverseNestedCGNodes) {
  SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
  auto addToWorklist = [&](CallGraphNode *node,
                           iterator_range<Region::iterator> blocks) {
    for (Block &block : blocks)
      worklist.emplace_back(&block, node);
  };

  addToWorklist(sourceNode, blocks);
  while (!worklist.empty()) {
    Block *block;
    std::tie(block, sourceNode) = worklist.pop_back_val();

    for (Operation &op : *block) {
      if (auto call = dyn_cast<CallOpInterface>(op)) {
        // TODO: Support inlining nested call references.
        CallInterfaceCallable callable = call.getCallableForCallee();
        if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(callable)) {
          if (!isa<FlatSymbolRefAttr>(symRef))
            continue;
        }

        CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
        if (!targetNode->isExternal())
          calls.emplace_back(call, sourceNode, targetNode);
        continue;
      }

      // If this is not a call, traverse the nested regions. If
      // `traverseNestedCGNodes` is false, then don't traverse nested call
      // graph nodes.
      for (auto &nestedRegion : op.getRegions()) {
        CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
        if (traverseNestedCGNodes || !nestedNode)
          addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
      }
    }
  }
}
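
// Illustrative example (editorial, not from the original file): in
//   func.func @caller(%cond: i1) {
//     scf.if %cond { func.call @callee() : () -> () }
//   }
// the scf.if region is not itself a callgraph node, so the worklist descends
// into it and the nested call is attributed to @caller's node.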

//===----------------------------------------------------------------------===//
// InlinerInterfaceImpl
//===----------------------------------------------------------------------===//

#ifndef NDEBUG
static std::string getNodeName(CallOpInterface op) {
  if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
    return debugString(op);
  return "_unnamed_callee_";
}
#endif

/// Return true if the specified `inlineHistoryID` indicates an inline history
/// that already includes `node`.
static bool inlineHistoryIncludes(
    CallGraphNode *node, std::optional<size_t> inlineHistoryID,
    MutableArrayRef<std::pair<CallGraphNode *, std::optional<size_t>>>
        inlineHistory) {
  while (inlineHistoryID.has_value()) {
    assert(*inlineHistoryID < inlineHistory.size() &&
           "Invalid inline history ID");
    if (inlineHistory[*inlineHistoryID].first == node)
      return true;
    inlineHistoryID = inlineHistory[*inlineHistoryID].second;
  }
  return false;
}
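
// Illustrative example (editorial): suppose inlining a call to @f exposes a
// call to @g, and inlining that exposes a call to @f again. The new call's
// history chain is {@g -> @f -> root}; inlineHistoryIncludes(@f's node, ...)
// walks the chain, finds @f, and the driver declines to inline it again,
// bounding repeated inlining through mutually recursive functions.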

namespace {
/// This class provides a specialization of the main inlining interface.
struct InlinerInterfaceImpl : public InlinerInterface {
  InlinerInterfaceImpl(MLIRContext *context, CallGraph &cg,
                       SymbolTableCollection &symbolTable)
      : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}

  /// Process a set of blocks that have been inlined. This callback is invoked
  /// *before* inlined terminator operations have been processed.
  void
  processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
    // Find the closest callgraph node from the first block.
    CallGraphNode *node;
    Region *region = inlinedBlocks.begin()->getParent();
    while (!(node = cg.lookupNode(region))) {
      region = region->getParentRegion();
      assert(region && "expected valid parent node");
    }

    collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
                   /*traverseNestedCGNodes=*/true);
  }

  /// Mark the given callgraph node for deletion.
  void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }

  /// This method properly disposes of callables that became dead during
  /// inlining. This should not be called while iterating over the SCCs.
  void eraseDeadCallables() {
    for (CallGraphNode *node : deadNodes)
      node->getCallableRegion()->getParentOp()->erase();
  }

  /// The set of callables known to be dead.
  SmallPtrSet<CallGraphNode *, 8> deadNodes;

  /// The current set of call instructions to consider for inlining.
  SmallVector<ResolvedCall, 8> calls;

  /// The callgraph being operated on.
  CallGraph &cg;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // namespace

class Inliner::Impl {
public:
  Impl(Inliner &inliner) : inliner(inliner) {}

  /// Attempt to inline calls within the given scc, and run simplifications,
  /// until a fixed point is reached. This allows for the inlining of newly
  /// devirtualized calls. Returns failure if there was a fatal error during
  /// inlining.
  LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface,
                          CGUseList &useList, CallGraphSCC &currentSCC,
                          MLIRContext *context);

private:
  /// Optimize the nodes within the given SCC with one of the held optimization
  /// pass pipelines. Returns failure if an error occurred during the
  /// optimization of the SCC, success otherwise.
  LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
                            CallGraphSCC &currentSCC, MLIRContext *context);

  /// Optimize the nodes within the given SCC in parallel. Returns failure if
  /// an error occurred during the optimization of the SCC, success otherwise.
  LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                                 MLIRContext *context);

  /// Optimize the given callable node with one of the pass managers provided
  /// with `pipelines`, or the generic pre-inline pipeline. Returns failure if
  /// an error occurred during the optimization of the callable, success
  /// otherwise.
  LogicalResult optimizeCallable(CallGraphNode *node,
                                 llvm::StringMap<OpPassManager> &pipelines);

  /// Attempt to inline calls within the given scc. This function returns
  /// success if any calls were inlined, failure otherwise.
  LogicalResult inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
                                 CGUseList &useList, CallGraphSCC &currentSCC);

  /// Returns true if the given call should be inlined.
  bool shouldInline(ResolvedCall &resolvedCall);

  /// The inliner.
  Inliner &inliner;
  llvm::SmallVector<llvm::StringMap<OpPassManager>> pipelines;
};

LogicalResult Inliner::Impl::inlineSCC(InlinerInterfaceImpl &inlinerIface,
                                       CGUseList &useList,
                                       CallGraphSCC &currentSCC,
                                       MLIRContext *context) {
  // Continuously simplify and inline until we either reach a fixed point, or
  // hit the maximum iteration count. Simplifying early helps to refine the
  // cost model, and in future iterations may devirtualize new calls.
  unsigned iterationCount = 0;
  do {
    if (failed(optimizeSCC(inlinerIface.cg, useList, currentSCC, context)))
      return failure();
    if (failed(inlineCallsInSCC(inlinerIface, useList, currentSCC)))
      break;
  } while (++iterationCount < inliner.config.getMaxInliningIterations());
  return success();
}

LogicalResult Inliner::Impl::optimizeSCC(CallGraph &cg, CGUseList &useList,
                                         CallGraphSCC &currentSCC,
                                         MLIRContext *context) {
  // Collect the sets of nodes to simplify.
  SmallVector<CallGraphNode *, 4> nodesToVisit;
  for (auto *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't simplify nodes with children. Nodes with children require special
    // handling as we may remove the node during simplification. In the future,
    // we should be able to handle this case with proper node deletion
    // tracking.
    if (node->hasChildren())
      continue;

    // We also won't apply simplifications to nodes that can't have passes
    // scheduled on them.
    auto *region = node->getCallableRegion();
    if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
      continue;
    nodesToVisit.push_back(node);
  }
  if (nodesToVisit.empty())
    return success();

  // Optimize each of the nodes within the SCC in parallel.
  if (failed(optimizeSCCAsync(nodesToVisit, context)))
    return failure();

  // Recompute the uses held by each of the nodes.
  for (CallGraphNode *node : nodesToVisit)
    useList.recomputeUses(node, cg);
  return success();
}

LogicalResult
Inliner::Impl::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                                MLIRContext *ctx) {
  // We must maintain a fixed pool of pass managers which is at least as large
  // as the maximum parallelism of the failableParallelForEach below.
  // Note: The number of pass managers here needs to remain constant
  // to prevent issues with pass instrumentations that rely on having the same
  // pass manager for the main thread.
  size_t numThreads = ctx->getNumThreads();
  const auto &opPipelines = inliner.config.getOpPipelines();
  if (pipelines.size() < numThreads) {
    pipelines.reserve(numThreads);
    pipelines.resize(numThreads, opPipelines);
  }

  // Ensure an analysis manager has been constructed for each of the nodes.
  // This prevents thread races when running the nested pipelines.
  for (CallGraphNode *node : nodesToVisit)
    inliner.am.nest(node->getCallableRegion()->getParentOp());

  // A pool of 'active' flags, one per pass manager, claimed atomically by the
  // async executors.
  std::vector<std::atomic<bool>> activePMs(pipelines.size());
  std::fill(activePMs.begin(), activePMs.end(), false);
  return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) {
    // Find a pass manager for this operation.
    auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
      bool expectedInactive = false;
      return isActive.compare_exchange_strong(expectedInactive, true);
    });
    assert(it != activePMs.end() &&
           "could not find inactive pass manager for thread");
    unsigned pmIndex = it - activePMs.begin();

    // Optimize this callable node.
    LogicalResult result = optimizeCallable(node, pipelines[pmIndex]);

    // Reset the active bit for this pass manager.
    activePMs[pmIndex].store(false);
    return result;
  });
}
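
// Editorial note: each worker claims an idle pass manager by atomically
// flipping its flag with compare_exchange_strong. failableParallelForEach
// runs at most `numThreads` bodies concurrently and the pool holds at least
// that many pass managers, so a free slot always exists and the assert holds.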

LogicalResult
Inliner::Impl::optimizeCallable(CallGraphNode *node,
                                llvm::StringMap<OpPassManager> &pipelines) {
  Operation *callable = node->getCallableRegion()->getParentOp();
  StringRef opName = callable->getName().getStringRef();
  auto pipelineIt = pipelines.find(opName);
  const auto &defaultPipeline = inliner.config.getDefaultPipeline();
  if (pipelineIt == pipelines.end()) {
    // If a pipeline didn't exist, use the generic pipeline if possible.
    if (!defaultPipeline)
      return success();

    OpPassManager defaultPM(opName);
    defaultPipeline(defaultPM);
    pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
  }
  return inliner.runPipelineHelper(inliner.pass, pipelineIt->second, callable);
}

/// Attempt to inline calls within the given scc. This function returns
/// success if any calls were inlined, failure otherwise.
LogicalResult
Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
                                CGUseList &useList, CallGraphSCC &currentSCC) {
  CallGraph &cg = inlinerIface.cg;
  auto &calls = inlinerIface.calls;

  // A set of dead nodes to remove after inlining.
  llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;

  // Collect all of the direct calls within the nodes of the current SCC. We
  // don't traverse nested callgraph nodes, because they are handled
  // separately, likely within a different SCC.
  for (CallGraphNode *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't collect calls if the node is already dead.
    if (useList.isDead(node)) {
      deadNodes.insert(node);
    } else {
      collectCallOps(*node->getCallableRegion(), node, cg,
                     inlinerIface.symbolTable, calls,
                     /*traverseNestedCGNodes=*/false);
    }
  }

  // When inlining a callee produces new call sites, we want to keep track of
  // the fact that they were inlined from the callee. This allows us to avoid
  // infinite inlining.
  using InlineHistoryT = std::optional<size_t>;
  SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory;
  std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});

  LLVM_DEBUG({
    llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
    for (unsigned i = 0, e = calls.size(); i < e; ++i)
      llvm::dbgs() << "  " << i << ". " << calls[i].call << ",\n";
    llvm::dbgs() << "}\n";
  });

  // Try to inline each of the call operations. Don't cache the end iterator
  // here as more calls may be added during inlining.
  bool inlinedAnyCalls = false;
  for (unsigned i = 0; i < calls.size(); ++i) {
    if (deadNodes.contains(calls[i].sourceNode))
      continue;
    ResolvedCall it = calls[i];

    InlineHistoryT inlineHistoryID = callHistory[i];
    bool inHistory =
        inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
    bool doInline = !inHistory && shouldInline(it);
    CallOpInterface call = it.call;
    LLVM_DEBUG({
      if (doInline)
        llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
      else
        llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
    });
    if (!doInline)
      continue;

    unsigned prevSize = calls.size();
    Region *targetRegion = it.targetNode->getCallableRegion();

    // If this is the last call to the target node and the node is discardable,
    // then inline it in-place and delete the node if successful.
    bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);

    LogicalResult inlineResult =
        inlineCall(inlinerIface, call,
                   cast<CallableOpInterface>(targetRegion->getParentOp()),
                   targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
    if (failed(inlineResult)) {
      LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
      continue;
    }
    inlinedAnyCalls = true;

    // Create an inline history entry for this inlined call, so that we
    // remember that new callsites came about due to inlining the callee.
    InlineHistoryT newInlineHistoryID{inlineHistory.size()};
    inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));

    auto historyToString = [](InlineHistoryT h) {
      return h.has_value() ? std::to_string(*h) : "root";
    };
    (void)historyToString;
    LLVM_DEBUG(llvm::dbgs()
               << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
               << getNodeName(call) << ", " << historyToString(inlineHistoryID)
               << "]\n");

    for (unsigned k = prevSize; k != calls.size(); ++k) {
      callHistory.push_back(newInlineHistoryID);
      LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[k].call
                              << "}\n   with historyID = " << newInlineHistoryID
                              << ", added due to inlining of\n  call {" << call
                              << "}\n with historyID = "
                              << historyToString(inlineHistoryID) << "\n");
    }

    // If the inlining was successful, merge the new uses into the source node.
    useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
    useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);

    // Then erase the call.
    call.erase();

    // If we inlined in place, mark the node for deletion.
    if (inlineInPlace) {
      useList.eraseNode(it.targetNode);
      deadNodes.insert(it.targetNode);
    }
  }

  for (CallGraphNode *node : deadNodes) {
    currentSCC.remove(node);
    inlinerIface.markForDeletion(node);
  }
  calls.clear();
  return success(inlinedAnyCalls);
}
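
// Editorial note: `calls` keeps growing while the loop above runs, because a
// successful inlineCall() invokes processInlinedBlocks(), which appends any
// call sites exposed by the inlined body; this is why the loop re-reads
// calls.size() each iteration instead of caching an end iterator.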

/// Returns true if the given call should be inlined.
bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {
  // Don't allow inlining terminator calls. We currently don't support this
  // case.
  if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
    return false;

  // Don't allow inlining if the target is a self-recursive function.
  if (llvm::count_if(*resolvedCall.targetNode,
                     [&](CallGraphNode::Edge const &edge) -> bool {
                       return edge.getTarget() == resolvedCall.targetNode;
                     }) > 0)
    return false;

  // Don't allow inlining if the target is an ancestor of the call. This
  // prevents inlining recursively.
  Region *callableRegion = resolvedCall.targetNode->getCallableRegion();
  if (callableRegion->isAncestor(resolvedCall.call->getParentRegion()))
    return false;

  // Don't allow inlining if the callee has multiple blocks (unstructured
  // control flow) but we cannot be sure that the caller region supports that.
  bool calleeHasMultipleBlocks =
      llvm::hasNItemsOrMore(*callableRegion, /*N=*/2);
  // If both parent ops have the same type, it is safe to inline. Otherwise,
  // decide based on whether the op has the SingleBlock trait or not.
  // Note: This check does not currently account for
  // SizedRegion/MaxSizedRegion.
  auto callerRegionSupportsMultipleBlocks = [&]() {
    return callableRegion->getParentOp()->getName() ==
               resolvedCall.call->getParentOp()->getName() ||
           !resolvedCall.call->getParentOp()
                ->mightHaveTrait<OpTrait::SingleBlock>();
  };
  if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
    return false;

  if (!inliner.isProfitableToInline(resolvedCall))
    return false;

  // Otherwise, inline.
  return true;
}
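
// Illustrative example (editorial): a callee whose region has two or more
// blocks (unstructured control flow) is rejected when the call site sits in
// an op that might be SingleBlock, e.g. inside an scf.if region, since
// splicing multiple blocks into such a region could produce invalid IR.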

LogicalResult Inliner::doInlining() {
  Impl impl(*this);
  auto *context = op->getContext();
  // Run the inline transform in post-order over the SCCs in the callgraph.
  SymbolTableCollection symbolTable;
  // FIXME: some clean-up can be done for the arguments
  // of the Impl's methods, if the inlinerIface and useList
  // become the states of the Impl.
  InlinerInterfaceImpl inlinerIface(context, cg, symbolTable);
  CGUseList useList(op, cg, symbolTable);
  LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
    return impl.inlineSCC(inlinerIface, useList, scc, context);
  });
  if (failed(result))
    return result;

  // After inlining, make sure to erase any callables proven to be dead.
  inlinerIface.eraseDeadCallables();
  return result;
}
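
// Usage sketch (editorial; assuming the constructor declared in
// mlir/Transforms/Inliner.h): a pass typically constructs an Inliner with its
// operation, the CallGraph analysis, a pipeline-running helper, an
// InlinerConfig, and a profitability callback, then checks doInlining():
//
//   CallGraph &cg = getAnalysis<CallGraph>();
//   Inliner inliner(getOperation(), cg, *this, getAnalysisManager(),
//                   runPipelineHelper, config, isProfitableToInline);
//   if (failed(inliner.doInlining()))
//     signalPassFailure();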