//===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous inlining utilities.
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/InliningUtils.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <iterator>
#include <optional>
#include <utility>

#define DEBUG_TYPE "inlining"

using namespace mlir;
28 /// Remap all locations reachable from the inlined blocks with CallSiteLoc
29 /// locations with the provided caller location.
31 remapInlinedLocations(iterator_range
<Region::iterator
> inlinedBlocks
,
33 DenseMap
<Location
, LocationAttr
> mappedLocations
;
34 auto remapLoc
= [&](Location loc
) {
35 auto [it
, inserted
] = mappedLocations
.try_emplace(loc
);
36 // Only query the attribute uniquer once per callsite attribute.
38 auto newLoc
= CallSiteLoc::get(loc
, callerLoc
);
39 it
->getSecond() = newLoc
;
44 AttrTypeReplacer attrReplacer
;
45 attrReplacer
.addReplacement(
46 [&](LocationAttr loc
) -> std::pair
<LocationAttr
, WalkResult
> {
47 return {remapLoc(loc
), WalkResult::skip()};
50 for (Block
&block
: inlinedBlocks
) {
51 for (BlockArgument
&arg
: block
.getArguments())
52 if (LocationAttr newLoc
= remapLoc(arg
.getLoc()))
55 for (Operation
&op
: block
)
56 attrReplacer
.recursivelyReplaceElementsIn(&op
, /*replaceAttrs=*/false,
57 /*replaceLocs=*/true);
61 static void remapInlinedOperands(iterator_range
<Region::iterator
> inlinedBlocks
,
63 auto remapOperands
= [&](Operation
*op
) {
64 for (auto &operand
: op
->getOpOperands())
65 if (auto mappedOp
= mapper
.lookupOrNull(operand
.get()))
66 operand
.set(mappedOp
);
68 for (auto &block
: inlinedBlocks
)
69 block
.walk(remapOperands
);
//===----------------------------------------------------------------------===//
// InlinerInterface
//===----------------------------------------------------------------------===//
76 bool InlinerInterface::isLegalToInline(Operation
*call
, Operation
*callable
,
77 bool wouldBeCloned
) const {
78 if (auto *handler
= getInterfaceFor(call
))
79 return handler
->isLegalToInline(call
, callable
, wouldBeCloned
);
83 bool InlinerInterface::isLegalToInline(Region
*dest
, Region
*src
,
85 IRMapping
&valueMapping
) const {
86 if (auto *handler
= getInterfaceFor(dest
->getParentOp()))
87 return handler
->isLegalToInline(dest
, src
, wouldBeCloned
, valueMapping
);
91 bool InlinerInterface::isLegalToInline(Operation
*op
, Region
*dest
,
93 IRMapping
&valueMapping
) const {
94 if (auto *handler
= getInterfaceFor(op
))
95 return handler
->isLegalToInline(op
, dest
, wouldBeCloned
, valueMapping
);
99 bool InlinerInterface::shouldAnalyzeRecursively(Operation
*op
) const {
100 auto *handler
= getInterfaceFor(op
);
101 return handler
? handler
->shouldAnalyzeRecursively(op
) : true;
104 /// Handle the given inlined terminator by replacing it with a new operation
106 void InlinerInterface::handleTerminator(Operation
*op
, Block
*newDest
) const {
107 auto *handler
= getInterfaceFor(op
);
108 assert(handler
&& "expected valid dialect handler");
109 handler
->handleTerminator(op
, newDest
);
112 /// Handle the given inlined terminator by replacing it with a new operation
114 void InlinerInterface::handleTerminator(Operation
*op
,
115 ValueRange valuesToRepl
) const {
116 auto *handler
= getInterfaceFor(op
);
117 assert(handler
&& "expected valid dialect handler");
118 handler
->handleTerminator(op
, valuesToRepl
);
121 Value
InlinerInterface::handleArgument(OpBuilder
&builder
, Operation
*call
,
122 Operation
*callable
, Value argument
,
123 DictionaryAttr argumentAttrs
) const {
124 auto *handler
= getInterfaceFor(callable
);
125 assert(handler
&& "expected valid dialect handler");
126 return handler
->handleArgument(builder
, call
, callable
, argument
,
130 Value
InlinerInterface::handleResult(OpBuilder
&builder
, Operation
*call
,
131 Operation
*callable
, Value result
,
132 DictionaryAttr resultAttrs
) const {
133 auto *handler
= getInterfaceFor(callable
);
134 assert(handler
&& "expected valid dialect handler");
135 return handler
->handleResult(builder
, call
, callable
, result
, resultAttrs
);
138 void InlinerInterface::processInlinedCallBlocks(
139 Operation
*call
, iterator_range
<Region::iterator
> inlinedBlocks
) const {
140 auto *handler
= getInterfaceFor(call
);
141 assert(handler
&& "expected valid dialect handler");
142 handler
->processInlinedCallBlocks(call
, inlinedBlocks
);
145 /// Utility to check that all of the operations within 'src' can be inlined.
146 static bool isLegalToInline(InlinerInterface
&interface
, Region
*src
,
147 Region
*insertRegion
, bool shouldCloneInlinedRegion
,
148 IRMapping
&valueMapping
) {
149 for (auto &block
: *src
) {
150 for (auto &op
: block
) {
151 // Check this operation.
152 if (!interface
.isLegalToInline(&op
, insertRegion
,
153 shouldCloneInlinedRegion
, valueMapping
)) {
155 llvm::dbgs() << "* Illegal to inline because of op: ";
160 // Check any nested regions.
161 if (interface
.shouldAnalyzeRecursively(&op
) &&
162 llvm::any_of(op
.getRegions(), [&](Region
®ion
) {
163 return !isLegalToInline(interface
, ®ion
, insertRegion
,
164 shouldCloneInlinedRegion
, valueMapping
);
//===----------------------------------------------------------------------===//
// Inline Methods
//===----------------------------------------------------------------------===//
176 static void handleArgumentImpl(InlinerInterface
&interface
, OpBuilder
&builder
,
177 CallOpInterface call
,
178 CallableOpInterface callable
,
180 // Unpack the argument attributes if there are any.
181 SmallVector
<DictionaryAttr
> argAttrs(
182 callable
.getCallableRegion()->getNumArguments(),
183 builder
.getDictionaryAttr({}));
184 if (ArrayAttr arrayAttr
= callable
.getArgAttrsAttr()) {
185 assert(arrayAttr
.size() == argAttrs
.size());
186 for (auto [idx
, attr
] : llvm::enumerate(arrayAttr
))
187 argAttrs
[idx
] = cast
<DictionaryAttr
>(attr
);
190 // Run the argument attribute handler for the given argument and attribute.
191 for (auto [blockArg
, argAttr
] :
192 llvm::zip(callable
.getCallableRegion()->getArguments(), argAttrs
)) {
193 Value newArgument
= interface
.handleArgument(
194 builder
, call
, callable
, mapper
.lookup(blockArg
), argAttr
);
195 assert(newArgument
.getType() == mapper
.lookup(blockArg
).getType() &&
196 "expected the argument type to not change");
198 // Update the mapping to point the new argument returned by the handler.
199 mapper
.map(blockArg
, newArgument
);
203 static void handleResultImpl(InlinerInterface
&interface
, OpBuilder
&builder
,
204 CallOpInterface call
, CallableOpInterface callable
,
205 ValueRange results
) {
206 // Unpack the result attributes if there are any.
207 SmallVector
<DictionaryAttr
> resAttrs(results
.size(),
208 builder
.getDictionaryAttr({}));
209 if (ArrayAttr arrayAttr
= callable
.getResAttrsAttr()) {
210 assert(arrayAttr
.size() == resAttrs
.size());
211 for (auto [idx
, attr
] : llvm::enumerate(arrayAttr
))
212 resAttrs
[idx
] = cast
<DictionaryAttr
>(attr
);
215 // Run the result attribute handler for the given result and attribute.
216 SmallVector
<DictionaryAttr
> resultAttributes
;
217 for (auto [result
, resAttr
] : llvm::zip(results
, resAttrs
)) {
218 // Store the original result users before running the handler.
219 DenseSet
<Operation
*> resultUsers
;
220 for (Operation
*user
: result
.getUsers())
221 resultUsers
.insert(user
);
224 interface
.handleResult(builder
, call
, callable
, result
, resAttr
);
225 assert(newResult
.getType() == result
.getType() &&
226 "expected the result type to not change");
228 // Replace the result uses except for the ones introduce by the handler.
229 result
.replaceUsesWithIf(newResult
, [&](OpOperand
&operand
) {
230 return resultUsers
.count(operand
.getOwner());
236 inlineRegionImpl(InlinerInterface
&interface
, Region
*src
, Block
*inlineBlock
,
237 Block::iterator inlinePoint
, IRMapping
&mapper
,
238 ValueRange resultsToReplace
, TypeRange regionResultTypes
,
239 std::optional
<Location
> inlineLoc
,
240 bool shouldCloneInlinedRegion
, CallOpInterface call
= {}) {
241 assert(resultsToReplace
.size() == regionResultTypes
.size());
242 // We expect the region to have at least one block.
246 // Check that all of the region arguments have been mapped.
247 auto *srcEntryBlock
= &src
->front();
248 if (llvm::any_of(srcEntryBlock
->getArguments(),
249 [&](BlockArgument arg
) { return !mapper
.contains(arg
); }))
252 // Check that the operations within the source region are valid to inline.
253 Region
*insertRegion
= inlineBlock
->getParent();
254 if (!interface
.isLegalToInline(insertRegion
, src
, shouldCloneInlinedRegion
,
256 !isLegalToInline(interface
, src
, insertRegion
, shouldCloneInlinedRegion
,
260 // Run the argument attribute handler before inlining the callable region.
261 OpBuilder
builder(inlineBlock
, inlinePoint
);
262 auto callable
= dyn_cast
<CallableOpInterface
>(src
->getParentOp());
263 if (call
&& callable
)
264 handleArgumentImpl(interface
, builder
, call
, callable
, mapper
);
266 // Check to see if the region is being cloned, or moved inline. In either
267 // case, move the new blocks after the 'insertBlock' to improve IR
269 Block
*postInsertBlock
= inlineBlock
->splitBlock(inlinePoint
);
270 if (shouldCloneInlinedRegion
)
271 src
->cloneInto(insertRegion
, postInsertBlock
->getIterator(), mapper
);
273 insertRegion
->getBlocks().splice(postInsertBlock
->getIterator(),
274 src
->getBlocks(), src
->begin(),
277 // Get the range of newly inserted blocks.
278 auto newBlocks
= llvm::make_range(std::next(inlineBlock
->getIterator()),
279 postInsertBlock
->getIterator());
280 Block
*firstNewBlock
= &*newBlocks
.begin();
282 // Remap the locations of the inlined operations if a valid source location
284 if (inlineLoc
&& !llvm::isa
<UnknownLoc
>(*inlineLoc
))
285 remapInlinedLocations(newBlocks
, *inlineLoc
);
287 // If the blocks were moved in-place, make sure to remap any necessary
289 if (!shouldCloneInlinedRegion
)
290 remapInlinedOperands(newBlocks
, mapper
);
292 // Process the newly inlined blocks.
294 interface
.processInlinedCallBlocks(call
, newBlocks
);
295 interface
.processInlinedBlocks(newBlocks
);
297 // Handle the case where only a single block was inlined.
298 if (std::next(newBlocks
.begin()) == newBlocks
.end()) {
299 // Run the result attribute handler on the terminator operands.
300 Operation
*firstBlockTerminator
= firstNewBlock
->getTerminator();
301 builder
.setInsertionPoint(firstBlockTerminator
);
302 if (call
&& callable
)
303 handleResultImpl(interface
, builder
, call
, callable
,
304 firstBlockTerminator
->getOperands());
306 // Have the interface handle the terminator of this block.
307 interface
.handleTerminator(firstBlockTerminator
, resultsToReplace
);
308 firstBlockTerminator
->erase();
310 // Merge the post insert block into the cloned entry block.
311 firstNewBlock
->getOperations().splice(firstNewBlock
->end(),
312 postInsertBlock
->getOperations());
313 postInsertBlock
->erase();
315 // Otherwise, there were multiple blocks inlined. Add arguments to the post
316 // insertion block to represent the results to replace.
317 for (const auto &resultToRepl
: llvm::enumerate(resultsToReplace
)) {
318 resultToRepl
.value().replaceAllUsesWith(
319 postInsertBlock
->addArgument(regionResultTypes
[resultToRepl
.index()],
320 resultToRepl
.value().getLoc()));
323 // Run the result attribute handler on the post insertion block arguments.
324 builder
.setInsertionPointToStart(postInsertBlock
);
325 if (call
&& callable
)
326 handleResultImpl(interface
, builder
, call
, callable
,
327 postInsertBlock
->getArguments());
329 /// Handle the terminators for each of the new blocks.
330 for (auto &newBlock
: newBlocks
)
331 interface
.handleTerminator(newBlock
.getTerminator(), postInsertBlock
);
334 // Splice the instructions of the inlined entry block into the insert block.
335 inlineBlock
->getOperations().splice(inlineBlock
->end(),
336 firstNewBlock
->getOperations());
337 firstNewBlock
->erase();
342 inlineRegionImpl(InlinerInterface
&interface
, Region
*src
, Block
*inlineBlock
,
343 Block::iterator inlinePoint
, ValueRange inlinedOperands
,
344 ValueRange resultsToReplace
, std::optional
<Location
> inlineLoc
,
345 bool shouldCloneInlinedRegion
, CallOpInterface call
= {}) {
346 // We expect the region to have at least one block.
350 auto *entryBlock
= &src
->front();
351 if (inlinedOperands
.size() != entryBlock
->getNumArguments())
354 // Map the provided call operands to the arguments of the region.
356 for (unsigned i
= 0, e
= inlinedOperands
.size(); i
!= e
; ++i
) {
357 // Verify that the types of the provided values match the function argument
359 BlockArgument regionArg
= entryBlock
->getArgument(i
);
360 if (inlinedOperands
[i
].getType() != regionArg
.getType())
362 mapper
.map(regionArg
, inlinedOperands
[i
]);
365 // Call into the main region inliner function.
366 return inlineRegionImpl(interface
, src
, inlineBlock
, inlinePoint
, mapper
,
367 resultsToReplace
, resultsToReplace
.getTypes(),
368 inlineLoc
, shouldCloneInlinedRegion
, call
);
371 LogicalResult
mlir::inlineRegion(InlinerInterface
&interface
, Region
*src
,
372 Operation
*inlinePoint
, IRMapping
&mapper
,
373 ValueRange resultsToReplace
,
374 TypeRange regionResultTypes
,
375 std::optional
<Location
> inlineLoc
,
376 bool shouldCloneInlinedRegion
) {
377 return inlineRegion(interface
, src
, inlinePoint
->getBlock(),
378 ++inlinePoint
->getIterator(), mapper
, resultsToReplace
,
379 regionResultTypes
, inlineLoc
, shouldCloneInlinedRegion
);
381 LogicalResult
mlir::inlineRegion(InlinerInterface
&interface
, Region
*src
,
383 Block::iterator inlinePoint
, IRMapping
&mapper
,
384 ValueRange resultsToReplace
,
385 TypeRange regionResultTypes
,
386 std::optional
<Location
> inlineLoc
,
387 bool shouldCloneInlinedRegion
) {
388 return inlineRegionImpl(interface
, src
, inlineBlock
, inlinePoint
, mapper
,
389 resultsToReplace
, regionResultTypes
, inlineLoc
,
390 shouldCloneInlinedRegion
);
393 LogicalResult
mlir::inlineRegion(InlinerInterface
&interface
, Region
*src
,
394 Operation
*inlinePoint
,
395 ValueRange inlinedOperands
,
396 ValueRange resultsToReplace
,
397 std::optional
<Location
> inlineLoc
,
398 bool shouldCloneInlinedRegion
) {
399 return inlineRegion(interface
, src
, inlinePoint
->getBlock(),
400 ++inlinePoint
->getIterator(), inlinedOperands
,
401 resultsToReplace
, inlineLoc
, shouldCloneInlinedRegion
);
403 LogicalResult
mlir::inlineRegion(InlinerInterface
&interface
, Region
*src
,
405 Block::iterator inlinePoint
,
406 ValueRange inlinedOperands
,
407 ValueRange resultsToReplace
,
408 std::optional
<Location
> inlineLoc
,
409 bool shouldCloneInlinedRegion
) {
410 return inlineRegionImpl(interface
, src
, inlineBlock
, inlinePoint
,
411 inlinedOperands
, resultsToReplace
, inlineLoc
,
412 shouldCloneInlinedRegion
);
415 /// Utility function used to generate a cast operation from the given interface,
416 /// or return nullptr if a cast could not be generated.
417 static Value
materializeConversion(const DialectInlinerInterface
*interface
,
418 SmallVectorImpl
<Operation
*> &castOps
,
419 OpBuilder
&castBuilder
, Value arg
, Type type
,
420 Location conversionLoc
) {
424 // Check to see if the interface for the call can materialize a conversion.
425 Operation
*castOp
= interface
->materializeCallConversion(castBuilder
, arg
,
426 type
, conversionLoc
);
429 castOps
.push_back(castOp
);
431 // Ensure that the generated cast is correct.
432 assert(castOp
->getNumOperands() == 1 && castOp
->getOperand(0) == arg
&&
433 castOp
->getNumResults() == 1 && *castOp
->result_type_begin() == type
);
434 return castOp
->getResult(0);
437 /// This function inlines a given region, 'src', of a callable operation,
438 /// 'callable', into the location defined by the given call operation. This
439 /// function returns failure if inlining is not possible, success otherwise. On
440 /// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
441 /// corresponds to whether the source region should be cloned into the 'call' or
442 /// spliced directly.
443 LogicalResult
mlir::inlineCall(InlinerInterface
&interface
,
444 CallOpInterface call
,
445 CallableOpInterface callable
, Region
*src
,
446 bool shouldCloneInlinedRegion
) {
447 // We expect the region to have at least one block.
450 auto *entryBlock
= &src
->front();
451 ArrayRef
<Type
> callableResultTypes
= callable
.getResultTypes();
453 // Make sure that the number of arguments and results matchup between the call
455 SmallVector
<Value
, 8> callOperands(call
.getArgOperands());
456 SmallVector
<Value
, 8> callResults(call
->getResults());
457 if (callOperands
.size() != entryBlock
->getNumArguments() ||
458 callResults
.size() != callableResultTypes
.size())
461 // A set of cast operations generated to matchup the signature of the region
462 // with the signature of the call.
463 SmallVector
<Operation
*, 4> castOps
;
464 castOps
.reserve(callOperands
.size() + callResults
.size());
466 // Functor used to cleanup generated state on failure.
467 auto cleanupState
= [&] {
468 for (auto *op
: castOps
) {
469 op
->getResult(0).replaceAllUsesWith(op
->getOperand(0));
475 // Builder used for any conversion operations that need to be materialized.
476 OpBuilder
castBuilder(call
);
477 Location castLoc
= call
.getLoc();
478 const auto *callInterface
= interface
.getInterfaceFor(call
->getDialect());
480 // Map the provided call operands to the arguments of the region.
482 for (unsigned i
= 0, e
= callOperands
.size(); i
!= e
; ++i
) {
483 BlockArgument regionArg
= entryBlock
->getArgument(i
);
484 Value operand
= callOperands
[i
];
486 // If the call operand doesn't match the expected region argument, try to
488 Type regionArgType
= regionArg
.getType();
489 if (operand
.getType() != regionArgType
) {
490 if (!(operand
= materializeConversion(callInterface
, castOps
, castBuilder
,
491 operand
, regionArgType
, castLoc
)))
492 return cleanupState();
494 mapper
.map(regionArg
, operand
);
497 // Ensure that the resultant values of the call match the callable.
498 castBuilder
.setInsertionPointAfter(call
);
499 for (unsigned i
= 0, e
= callResults
.size(); i
!= e
; ++i
) {
500 Value callResult
= callResults
[i
];
501 if (callResult
.getType() == callableResultTypes
[i
])
504 // Generate a conversion that will produce the original type, so that the IR
505 // is still valid after the original call gets replaced.
507 materializeConversion(callInterface
, castOps
, castBuilder
, callResult
,
508 callResult
.getType(), castLoc
);
510 return cleanupState();
511 callResult
.replaceAllUsesWith(castResult
);
512 castResult
.getDefiningOp()->replaceUsesOfWith(castResult
, callResult
);
515 // Check that it is legal to inline the callable into the call.
516 if (!interface
.isLegalToInline(call
, callable
, shouldCloneInlinedRegion
))
517 return cleanupState();
519 // Attempt to inline the call.
520 if (failed(inlineRegionImpl(interface
, src
, call
->getBlock(),
521 ++call
->getIterator(), mapper
, callResults
,
522 callableResultTypes
, call
.getLoc(),
523 shouldCloneInlinedRegion
, call
)))
524 return cleanupState();