[rtsan] Remove mkfifoat interceptor (#116997)
[llvm-project.git] / mlir / lib / Transforms / Utils / InliningUtils.cpp
blob0db097d14cd3c72c19557b56921e03cf8dd7f837
1 //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous inlining utilities.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Transforms/InliningUtils.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/IRMapping.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/Interfaces/CallInterfaces.h"
19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
22 #include <optional>
24 #define DEBUG_TYPE "inlining"
26 using namespace mlir;
28 /// Remap all locations reachable from the inlined blocks with CallSiteLoc
29 /// locations with the provided caller location.
30 static void
31 remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks,
32 Location callerLoc) {
33 DenseMap<Location, LocationAttr> mappedLocations;
34 auto remapLoc = [&](Location loc) {
35 auto [it, inserted] = mappedLocations.try_emplace(loc);
36 // Only query the attribute uniquer once per callsite attribute.
37 if (inserted) {
38 auto newLoc = CallSiteLoc::get(loc, callerLoc);
39 it->getSecond() = newLoc;
41 return it->second;
44 AttrTypeReplacer attrReplacer;
45 attrReplacer.addReplacement(
46 [&](LocationAttr loc) -> std::pair<LocationAttr, WalkResult> {
47 return {remapLoc(loc), WalkResult::skip()};
48 });
50 for (Block &block : inlinedBlocks) {
51 for (BlockArgument &arg : block.getArguments())
52 if (LocationAttr newLoc = remapLoc(arg.getLoc()))
53 arg.setLoc(newLoc);
55 for (Operation &op : block)
56 attrReplacer.recursivelyReplaceElementsIn(&op, /*replaceAttrs=*/false,
57 /*replaceLocs=*/true);
61 static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,
62 IRMapping &mapper) {
63 auto remapOperands = [&](Operation *op) {
64 for (auto &operand : op->getOpOperands())
65 if (auto mappedOp = mapper.lookupOrNull(operand.get()))
66 operand.set(mappedOp);
68 for (auto &block : inlinedBlocks)
69 block.walk(remapOperands);
72 //===----------------------------------------------------------------------===//
73 // InlinerInterface
74 //===----------------------------------------------------------------------===//
76 bool InlinerInterface::isLegalToInline(Operation *call, Operation *callable,
77 bool wouldBeCloned) const {
78 if (auto *handler = getInterfaceFor(call))
79 return handler->isLegalToInline(call, callable, wouldBeCloned);
80 return false;
83 bool InlinerInterface::isLegalToInline(Region *dest, Region *src,
84 bool wouldBeCloned,
85 IRMapping &valueMapping) const {
86 if (auto *handler = getInterfaceFor(dest->getParentOp()))
87 return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping);
88 return false;
91 bool InlinerInterface::isLegalToInline(Operation *op, Region *dest,
92 bool wouldBeCloned,
93 IRMapping &valueMapping) const {
94 if (auto *handler = getInterfaceFor(op))
95 return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping);
96 return false;
99 bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const {
100 auto *handler = getInterfaceFor(op);
101 return handler ? handler->shouldAnalyzeRecursively(op) : true;
104 /// Handle the given inlined terminator by replacing it with a new operation
105 /// as necessary.
106 void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const {
107 auto *handler = getInterfaceFor(op);
108 assert(handler && "expected valid dialect handler");
109 handler->handleTerminator(op, newDest);
112 /// Handle the given inlined terminator by replacing it with a new operation
113 /// as necessary.
114 void InlinerInterface::handleTerminator(Operation *op,
115 ValueRange valuesToRepl) const {
116 auto *handler = getInterfaceFor(op);
117 assert(handler && "expected valid dialect handler");
118 handler->handleTerminator(op, valuesToRepl);
121 Value InlinerInterface::handleArgument(OpBuilder &builder, Operation *call,
122 Operation *callable, Value argument,
123 DictionaryAttr argumentAttrs) const {
124 auto *handler = getInterfaceFor(callable);
125 assert(handler && "expected valid dialect handler");
126 return handler->handleArgument(builder, call, callable, argument,
127 argumentAttrs);
130 Value InlinerInterface::handleResult(OpBuilder &builder, Operation *call,
131 Operation *callable, Value result,
132 DictionaryAttr resultAttrs) const {
133 auto *handler = getInterfaceFor(callable);
134 assert(handler && "expected valid dialect handler");
135 return handler->handleResult(builder, call, callable, result, resultAttrs);
138 void InlinerInterface::processInlinedCallBlocks(
139 Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
140 auto *handler = getInterfaceFor(call);
141 assert(handler && "expected valid dialect handler");
142 handler->processInlinedCallBlocks(call, inlinedBlocks);
145 /// Utility to check that all of the operations within 'src' can be inlined.
146 static bool isLegalToInline(InlinerInterface &interface, Region *src,
147 Region *insertRegion, bool shouldCloneInlinedRegion,
148 IRMapping &valueMapping) {
149 for (auto &block : *src) {
150 for (auto &op : block) {
151 // Check this operation.
152 if (!interface.isLegalToInline(&op, insertRegion,
153 shouldCloneInlinedRegion, valueMapping)) {
154 LLVM_DEBUG({
155 llvm::dbgs() << "* Illegal to inline because of op: ";
156 op.dump();
158 return false;
160 // Check any nested regions.
161 if (interface.shouldAnalyzeRecursively(&op) &&
162 llvm::any_of(op.getRegions(), [&](Region &region) {
163 return !isLegalToInline(interface, &region, insertRegion,
164 shouldCloneInlinedRegion, valueMapping);
166 return false;
169 return true;
172 //===----------------------------------------------------------------------===//
173 // Inline Methods
174 //===----------------------------------------------------------------------===//
176 static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder,
177 CallOpInterface call,
178 CallableOpInterface callable,
179 IRMapping &mapper) {
180 // Unpack the argument attributes if there are any.
181 SmallVector<DictionaryAttr> argAttrs(
182 callable.getCallableRegion()->getNumArguments(),
183 builder.getDictionaryAttr({}));
184 if (ArrayAttr arrayAttr = callable.getArgAttrsAttr()) {
185 assert(arrayAttr.size() == argAttrs.size());
186 for (auto [idx, attr] : llvm::enumerate(arrayAttr))
187 argAttrs[idx] = cast<DictionaryAttr>(attr);
190 // Run the argument attribute handler for the given argument and attribute.
191 for (auto [blockArg, argAttr] :
192 llvm::zip(callable.getCallableRegion()->getArguments(), argAttrs)) {
193 Value newArgument = interface.handleArgument(
194 builder, call, callable, mapper.lookup(blockArg), argAttr);
195 assert(newArgument.getType() == mapper.lookup(blockArg).getType() &&
196 "expected the argument type to not change");
198 // Update the mapping to point the new argument returned by the handler.
199 mapper.map(blockArg, newArgument);
203 static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder,
204 CallOpInterface call, CallableOpInterface callable,
205 ValueRange results) {
206 // Unpack the result attributes if there are any.
207 SmallVector<DictionaryAttr> resAttrs(results.size(),
208 builder.getDictionaryAttr({}));
209 if (ArrayAttr arrayAttr = callable.getResAttrsAttr()) {
210 assert(arrayAttr.size() == resAttrs.size());
211 for (auto [idx, attr] : llvm::enumerate(arrayAttr))
212 resAttrs[idx] = cast<DictionaryAttr>(attr);
215 // Run the result attribute handler for the given result and attribute.
216 SmallVector<DictionaryAttr> resultAttributes;
217 for (auto [result, resAttr] : llvm::zip(results, resAttrs)) {
218 // Store the original result users before running the handler.
219 DenseSet<Operation *> resultUsers;
220 for (Operation *user : result.getUsers())
221 resultUsers.insert(user);
223 Value newResult =
224 interface.handleResult(builder, call, callable, result, resAttr);
225 assert(newResult.getType() == result.getType() &&
226 "expected the result type to not change");
228 // Replace the result uses except for the ones introduce by the handler.
229 result.replaceUsesWithIf(newResult, [&](OpOperand &operand) {
230 return resultUsers.count(operand.getOwner());
235 static LogicalResult
236 inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
237 Block::iterator inlinePoint, IRMapping &mapper,
238 ValueRange resultsToReplace, TypeRange regionResultTypes,
239 std::optional<Location> inlineLoc,
240 bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
241 assert(resultsToReplace.size() == regionResultTypes.size());
242 // We expect the region to have at least one block.
243 if (src->empty())
244 return failure();
246 // Check that all of the region arguments have been mapped.
247 auto *srcEntryBlock = &src->front();
248 if (llvm::any_of(srcEntryBlock->getArguments(),
249 [&](BlockArgument arg) { return !mapper.contains(arg); }))
250 return failure();
252 // Check that the operations within the source region are valid to inline.
253 Region *insertRegion = inlineBlock->getParent();
254 if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion,
255 mapper) ||
256 !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion,
257 mapper))
258 return failure();
260 // Run the argument attribute handler before inlining the callable region.
261 OpBuilder builder(inlineBlock, inlinePoint);
262 auto callable = dyn_cast<CallableOpInterface>(src->getParentOp());
263 if (call && callable)
264 handleArgumentImpl(interface, builder, call, callable, mapper);
266 // Check to see if the region is being cloned, or moved inline. In either
267 // case, move the new blocks after the 'insertBlock' to improve IR
268 // readability.
269 Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint);
270 if (shouldCloneInlinedRegion)
271 src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
272 else
273 insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
274 src->getBlocks(), src->begin(),
275 src->end());
277 // Get the range of newly inserted blocks.
278 auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
279 postInsertBlock->getIterator());
280 Block *firstNewBlock = &*newBlocks.begin();
282 // Remap the locations of the inlined operations if a valid source location
283 // was provided.
284 if (inlineLoc && !llvm::isa<UnknownLoc>(*inlineLoc))
285 remapInlinedLocations(newBlocks, *inlineLoc);
287 // If the blocks were moved in-place, make sure to remap any necessary
288 // operands.
289 if (!shouldCloneInlinedRegion)
290 remapInlinedOperands(newBlocks, mapper);
292 // Process the newly inlined blocks.
293 if (call)
294 interface.processInlinedCallBlocks(call, newBlocks);
295 interface.processInlinedBlocks(newBlocks);
297 // Handle the case where only a single block was inlined.
298 if (std::next(newBlocks.begin()) == newBlocks.end()) {
299 // Run the result attribute handler on the terminator operands.
300 Operation *firstBlockTerminator = firstNewBlock->getTerminator();
301 builder.setInsertionPoint(firstBlockTerminator);
302 if (call && callable)
303 handleResultImpl(interface, builder, call, callable,
304 firstBlockTerminator->getOperands());
306 // Have the interface handle the terminator of this block.
307 interface.handleTerminator(firstBlockTerminator, resultsToReplace);
308 firstBlockTerminator->erase();
310 // Merge the post insert block into the cloned entry block.
311 firstNewBlock->getOperations().splice(firstNewBlock->end(),
312 postInsertBlock->getOperations());
313 postInsertBlock->erase();
314 } else {
315 // Otherwise, there were multiple blocks inlined. Add arguments to the post
316 // insertion block to represent the results to replace.
317 for (const auto &resultToRepl : llvm::enumerate(resultsToReplace)) {
318 resultToRepl.value().replaceAllUsesWith(
319 postInsertBlock->addArgument(regionResultTypes[resultToRepl.index()],
320 resultToRepl.value().getLoc()));
323 // Run the result attribute handler on the post insertion block arguments.
324 builder.setInsertionPointToStart(postInsertBlock);
325 if (call && callable)
326 handleResultImpl(interface, builder, call, callable,
327 postInsertBlock->getArguments());
329 /// Handle the terminators for each of the new blocks.
330 for (auto &newBlock : newBlocks)
331 interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);
334 // Splice the instructions of the inlined entry block into the insert block.
335 inlineBlock->getOperations().splice(inlineBlock->end(),
336 firstNewBlock->getOperations());
337 firstNewBlock->erase();
338 return success();
341 static LogicalResult
342 inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
343 Block::iterator inlinePoint, ValueRange inlinedOperands,
344 ValueRange resultsToReplace, std::optional<Location> inlineLoc,
345 bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
346 // We expect the region to have at least one block.
347 if (src->empty())
348 return failure();
350 auto *entryBlock = &src->front();
351 if (inlinedOperands.size() != entryBlock->getNumArguments())
352 return failure();
354 // Map the provided call operands to the arguments of the region.
355 IRMapping mapper;
356 for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
357 // Verify that the types of the provided values match the function argument
358 // types.
359 BlockArgument regionArg = entryBlock->getArgument(i);
360 if (inlinedOperands[i].getType() != regionArg.getType())
361 return failure();
362 mapper.map(regionArg, inlinedOperands[i]);
365 // Call into the main region inliner function.
366 return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
367 resultsToReplace, resultsToReplace.getTypes(),
368 inlineLoc, shouldCloneInlinedRegion, call);
371 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
372 Operation *inlinePoint, IRMapping &mapper,
373 ValueRange resultsToReplace,
374 TypeRange regionResultTypes,
375 std::optional<Location> inlineLoc,
376 bool shouldCloneInlinedRegion) {
377 return inlineRegion(interface, src, inlinePoint->getBlock(),
378 ++inlinePoint->getIterator(), mapper, resultsToReplace,
379 regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
381 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
382 Block *inlineBlock,
383 Block::iterator inlinePoint, IRMapping &mapper,
384 ValueRange resultsToReplace,
385 TypeRange regionResultTypes,
386 std::optional<Location> inlineLoc,
387 bool shouldCloneInlinedRegion) {
388 return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
389 resultsToReplace, regionResultTypes, inlineLoc,
390 shouldCloneInlinedRegion);
393 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
394 Operation *inlinePoint,
395 ValueRange inlinedOperands,
396 ValueRange resultsToReplace,
397 std::optional<Location> inlineLoc,
398 bool shouldCloneInlinedRegion) {
399 return inlineRegion(interface, src, inlinePoint->getBlock(),
400 ++inlinePoint->getIterator(), inlinedOperands,
401 resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
403 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
404 Block *inlineBlock,
405 Block::iterator inlinePoint,
406 ValueRange inlinedOperands,
407 ValueRange resultsToReplace,
408 std::optional<Location> inlineLoc,
409 bool shouldCloneInlinedRegion) {
410 return inlineRegionImpl(interface, src, inlineBlock, inlinePoint,
411 inlinedOperands, resultsToReplace, inlineLoc,
412 shouldCloneInlinedRegion);
415 /// Utility function used to generate a cast operation from the given interface,
416 /// or return nullptr if a cast could not be generated.
417 static Value materializeConversion(const DialectInlinerInterface *interface,
418 SmallVectorImpl<Operation *> &castOps,
419 OpBuilder &castBuilder, Value arg, Type type,
420 Location conversionLoc) {
421 if (!interface)
422 return nullptr;
424 // Check to see if the interface for the call can materialize a conversion.
425 Operation *castOp = interface->materializeCallConversion(castBuilder, arg,
426 type, conversionLoc);
427 if (!castOp)
428 return nullptr;
429 castOps.push_back(castOp);
431 // Ensure that the generated cast is correct.
432 assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
433 castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
434 return castOp->getResult(0);
437 /// This function inlines a given region, 'src', of a callable operation,
438 /// 'callable', into the location defined by the given call operation. This
439 /// function returns failure if inlining is not possible, success otherwise. On
440 /// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
441 /// corresponds to whether the source region should be cloned into the 'call' or
442 /// spliced directly.
443 LogicalResult mlir::inlineCall(InlinerInterface &interface,
444 CallOpInterface call,
445 CallableOpInterface callable, Region *src,
446 bool shouldCloneInlinedRegion) {
447 // We expect the region to have at least one block.
448 if (src->empty())
449 return failure();
450 auto *entryBlock = &src->front();
451 ArrayRef<Type> callableResultTypes = callable.getResultTypes();
453 // Make sure that the number of arguments and results matchup between the call
454 // and the region.
455 SmallVector<Value, 8> callOperands(call.getArgOperands());
456 SmallVector<Value, 8> callResults(call->getResults());
457 if (callOperands.size() != entryBlock->getNumArguments() ||
458 callResults.size() != callableResultTypes.size())
459 return failure();
461 // A set of cast operations generated to matchup the signature of the region
462 // with the signature of the call.
463 SmallVector<Operation *, 4> castOps;
464 castOps.reserve(callOperands.size() + callResults.size());
466 // Functor used to cleanup generated state on failure.
467 auto cleanupState = [&] {
468 for (auto *op : castOps) {
469 op->getResult(0).replaceAllUsesWith(op->getOperand(0));
470 op->erase();
472 return failure();
475 // Builder used for any conversion operations that need to be materialized.
476 OpBuilder castBuilder(call);
477 Location castLoc = call.getLoc();
478 const auto *callInterface = interface.getInterfaceFor(call->getDialect());
480 // Map the provided call operands to the arguments of the region.
481 IRMapping mapper;
482 for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
483 BlockArgument regionArg = entryBlock->getArgument(i);
484 Value operand = callOperands[i];
486 // If the call operand doesn't match the expected region argument, try to
487 // generate a cast.
488 Type regionArgType = regionArg.getType();
489 if (operand.getType() != regionArgType) {
490 if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
491 operand, regionArgType, castLoc)))
492 return cleanupState();
494 mapper.map(regionArg, operand);
497 // Ensure that the resultant values of the call match the callable.
498 castBuilder.setInsertionPointAfter(call);
499 for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
500 Value callResult = callResults[i];
501 if (callResult.getType() == callableResultTypes[i])
502 continue;
504 // Generate a conversion that will produce the original type, so that the IR
505 // is still valid after the original call gets replaced.
506 Value castResult =
507 materializeConversion(callInterface, castOps, castBuilder, callResult,
508 callResult.getType(), castLoc);
509 if (!castResult)
510 return cleanupState();
511 callResult.replaceAllUsesWith(castResult);
512 castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult);
515 // Check that it is legal to inline the callable into the call.
516 if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion))
517 return cleanupState();
519 // Attempt to inline the call.
520 if (failed(inlineRegionImpl(interface, src, call->getBlock(),
521 ++call->getIterator(), mapper, callResults,
522 callableResultTypes, call.getLoc(),
523 shouldCloneInlinedRegion, call)))
524 return cleanupState();
525 return success();