[mlir][Transforms][NFC] Dialect conversion: Remove "finalize" phase (#116934)
[llvm-project.git] / mlir / lib / Transforms / Utils / DialectConversion.cpp
blob 03d483f73f255e69d73b4b26d010994092a7b0c5
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Transforms/DialectConversion.h"
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/IR/Iterators.h"
16 #include "mlir/Interfaces/FunctionInterfaces.h"
17 #include "mlir/Rewrite/PatternApplicator.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/Support/SaveAndRestore.h"
24 #include "llvm/Support/ScopedPrinter.h"
25 #include <optional>
27 using namespace mlir;
28 using namespace mlir::detail;
30 #define DEBUG_TYPE "dialect-conversion"
32 /// A utility function to log a successful result for the given reason.
33 template <typename... Args>
34 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
35 LLVM_DEBUG({
36 os.unindent();
37 os.startLine() << "} -> SUCCESS";
38 if (!fmt.empty())
39 os.getOStream() << " : "
40 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
41 os.getOStream() << "\n";
42 });
45 /// A utility function to log a failure result for the given reason.
46 template <typename... Args>
47 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
48 LLVM_DEBUG({
49 os.unindent();
50 os.startLine() << "} -> FAILURE : "
51 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
52 << "\n";
53 });
56 /// Helper function that computes an insertion point where the given value is
57 /// defined and can be used without a dominance violation.
58 static OpBuilder::InsertPoint computeInsertPoint(Value value) {
59 Block *insertBlock = value.getParentBlock();
60 Block::iterator insertPt = insertBlock->begin();
61 if (OpResult inputRes = dyn_cast<OpResult>(value))
62 insertPt = ++inputRes.getOwner()->getIterator();
63 return OpBuilder::InsertPoint(insertBlock, insertPt);
66 //===----------------------------------------------------------------------===//
67 // ConversionValueMapping
68 //===----------------------------------------------------------------------===//
70 /// A list of replacement SSA values. Optimized for the common case of a single
71 /// SSA value.
72 using ReplacementValues = SmallVector<Value, 1>;
74 namespace {
75 /// This class wraps a IRMapping to provide recursive lookup
76 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
77 struct ConversionValueMapping {
78 /// Return "true" if an SSA value is mapped to the given value. May return
79 /// false positives.
80 bool isMappedTo(Value value) const { return mappedTo.contains(value); }
82 /// Lookup the most recently mapped value with the desired type in the
83 /// mapping.
84 ///
85 /// Special cases:
86 /// - If the desired type is "null", simply return the most recently mapped
87 /// value.
88 /// - If there is no mapping to the desired type, also return the most
89 /// recently mapped value.
90 /// - If there is no mapping for the given value at all, return the given
91 /// value.
92 Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
94 /// Lookup a mapped value within the map, or return null if a mapping does not
95 /// exist. If a mapping exists, this follows the same behavior of
96 /// `lookupOrDefault`.
97 Value lookupOrNull(Value from, Type desiredType = nullptr) const;
99 /// Map a value to the one provided.
100 void map(Value oldVal, Value newVal) {
101 LLVM_DEBUG({
102 for (Value it = newVal; it; it = mapping.lookupOrNull(it))
103 assert(it != oldVal && "inserting cyclic mapping");
105 mapping.map(oldVal, newVal);
106 mappedTo.insert(newVal);
109 /// Drop the last mapping for the given value.
110 void erase(Value value) { mapping.erase(value); }
112 private:
113 /// Current value mappings.
114 IRMapping mapping;
116 /// All SSA values that are mapped to. May contain false positives.
117 DenseSet<Value> mappedTo;
119 } // namespace
121 Value ConversionValueMapping::lookupOrDefault(Value from,
122 Type desiredType) const {
123 // Try to find the deepest value that has the desired type. If there is no
124 // such value, simply return the deepest value.
125 Value desiredValue;
126 do {
127 if (!desiredType || from.getType() == desiredType)
128 desiredValue = from;
130 Value mappedValue = mapping.lookupOrNull(from);
131 if (!mappedValue)
132 break;
133 from = mappedValue;
134 } while (true);
136 // If the desired value was found use it, otherwise default to the leaf value.
137 return desiredValue ? desiredValue : from;
140 Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
141 Value result = lookupOrDefault(from, desiredType);
142 if (result == from || (desiredType && result.getType() != desiredType))
143 return nullptr;
144 return result;
147 //===----------------------------------------------------------------------===//
148 // Rewriter and Translation State
149 //===----------------------------------------------------------------------===//
150 namespace {
151 /// This class contains a snapshot of the current conversion rewriter state.
152 /// This is useful when saving and undoing a set of rewrites.
153 struct RewriterState {
154 RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
155 unsigned numReplacedOps)
156 : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
157 numReplacedOps(numReplacedOps) {}
159 /// The current number of rewrites performed.
160 unsigned numRewrites;
162 /// The current number of ignored operations.
163 unsigned numIgnoredOperations;
165 /// The current number of replaced ops that are scheduled for erasure.
166 unsigned numReplacedOps;
169 //===----------------------------------------------------------------------===//
170 // IR rewrites
171 //===----------------------------------------------------------------------===//
173 /// An IR rewrite that can be committed (upon success) or rolled back (upon
174 /// failure).
176 /// The dialect conversion keeps track of IR modifications (requested by the
177 /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
178 /// are directly applied to the IR as the rewriter API is used, some are applied
179 /// partially, and some are delayed until the `IRRewrite` objects are committed.
180 class IRRewrite {
181 public:
182 /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
183 /// Enum values are ordered, so that they can be used in `classof`: first all
184 /// block rewrites, then all operation rewrites.
185 enum class Kind {
186 // Block rewrites
187 CreateBlock,
188 EraseBlock,
189 InlineBlock,
190 MoveBlock,
191 BlockTypeConversion,
192 ReplaceBlockArg,
193 // Operation rewrites
194 MoveOperation,
195 ModifyOperation,
196 ReplaceOperation,
197 CreateOperation,
198 UnresolvedMaterialization
201 virtual ~IRRewrite() = default;
203 /// Roll back the rewrite. Operations may be erased during rollback.
204 virtual void rollback() = 0;
206 /// Commit the rewrite. At this point, it is certain that the dialect
207 /// conversion will succeed. All IR modifications, except for operation/block
208 /// erasure, must be performed through the given rewriter.
210 /// Instead of erasing operations/blocks, they should merely be unlinked
211 /// commit phase and finally be erased during the cleanup phase. This is
212 /// because internal dialect conversion state (such as `mapping`) may still
213 /// be using them.
215 /// Any IR modification that was already performed before the commit phase
216 /// (e.g., insertion of an op) must be communicated to the listener that may
217 /// be attached to the given rewriter.
218 virtual void commit(RewriterBase &rewriter) {}
220 /// Cleanup operations/blocks. Cleanup is called after commit.
221 virtual void cleanup(RewriterBase &rewriter) {}
223 Kind getKind() const { return kind; }
225 static bool classof(const IRRewrite *rewrite) { return true; }
227 protected:
228 IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
229 : kind(kind), rewriterImpl(rewriterImpl) {}
231 const ConversionConfig &getConfig() const;
233 const Kind kind;
234 ConversionPatternRewriterImpl &rewriterImpl;
237 /// A block rewrite.
238 class BlockRewrite : public IRRewrite {
239 public:
240 /// Return the block that this rewrite operates on.
241 Block *getBlock() const { return block; }
243 static bool classof(const IRRewrite *rewrite) {
244 return rewrite->getKind() >= Kind::CreateBlock &&
245 rewrite->getKind() <= Kind::ReplaceBlockArg;
248 protected:
249 BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
250 Block *block)
251 : IRRewrite(kind, rewriterImpl), block(block) {}
253 // The block that this rewrite operates on.
254 Block *block;
257 /// Creation of a block. Block creations are immediately reflected in the IR.
258 /// There is no extra work to commit the rewrite. During rollback, the newly
259 /// created block is erased.
260 class CreateBlockRewrite : public BlockRewrite {
261 public:
262 CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
263 : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
265 static bool classof(const IRRewrite *rewrite) {
266 return rewrite->getKind() == Kind::CreateBlock;
269 void commit(RewriterBase &rewriter) override {
270 // The block was already created and inserted. Just inform the listener.
271 if (auto *listener = rewriter.getListener())
272 listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
275 void rollback() override {
276 // Unlink all of the operations within this block, they will be deleted
277 // separately.
278 auto &blockOps = block->getOperations();
279 while (!blockOps.empty())
280 blockOps.remove(blockOps.begin());
281 block->dropAllUses();
282 if (block->getParent())
283 block->erase();
284 else
285 delete block;
289 /// Erasure of a block. Block erasures are partially reflected in the IR. Erased
290 /// blocks are immediately unlinked, but only erased during cleanup. This makes
291 /// it easier to rollback a block erasure: the block is simply inserted into its
292 /// original location.
293 class EraseBlockRewrite : public BlockRewrite {
294 public:
295 EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
296 : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
297 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
299 static bool classof(const IRRewrite *rewrite) {
300 return rewrite->getKind() == Kind::EraseBlock;
303 ~EraseBlockRewrite() override {
304 assert(!block &&
305 "rewrite was neither rolled back nor committed/cleaned up");
308 void rollback() override {
309 // The block (owned by this rewrite) was not actually erased yet. It was
310 // just unlinked. Put it back into its original position.
311 assert(block && "expected block");
312 auto &blockList = region->getBlocks();
313 Region::iterator before = insertBeforeBlock
314 ? Region::iterator(insertBeforeBlock)
315 : blockList.end();
316 blockList.insert(before, block);
317 block = nullptr;
320 void commit(RewriterBase &rewriter) override {
321 // Erase the block.
322 assert(block && "expected block");
323 assert(block->empty() && "expected empty block");
325 // Notify the listener that the block is about to be erased.
326 if (auto *listener =
327 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
328 listener->notifyBlockErased(block);
331 void cleanup(RewriterBase &rewriter) override {
332 // Erase the block.
333 block->dropAllDefinedValueUses();
334 delete block;
335 block = nullptr;
338 private:
339 // The region in which this block was previously contained.
340 Region *region;
342 // The original successor of this block before it was unlinked. "nullptr" if
343 // this block was the only block in the region.
344 Block *insertBeforeBlock;
347 /// Inlining of a block. This rewrite is immediately reflected in the IR.
348 /// Note: This rewrite represents only the inlining of the operations. The
349 /// erasure of the inlined block is a separate rewrite.
350 class InlineBlockRewrite : public BlockRewrite {
351 public:
352 InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
353 Block *sourceBlock, Block::iterator before)
354 : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
355 sourceBlock(sourceBlock),
356 firstInlinedInst(sourceBlock->empty() ? nullptr
357 : &sourceBlock->front()),
358 lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
359 // If a listener is attached to the dialect conversion, ops must be moved
360 // one-by-one. When they are moved in bulk, notifications cannot be sent
361 // because the ops that used to be in the source block at the time of the
362 // inlining (before the "commit" phase) are unknown at the time when
363 // notifications are sent (which is during the "commit" phase).
364 assert(!getConfig().listener &&
365 "InlineBlockRewrite not supported if listener is attached");
368 static bool classof(const IRRewrite *rewrite) {
369 return rewrite->getKind() == Kind::InlineBlock;
372 void rollback() override {
373 // Put the operations from the destination block (owned by the rewrite)
374 // back into the source block.
375 if (firstInlinedInst) {
376 assert(lastInlinedInst && "expected operation");
377 sourceBlock->getOperations().splice(sourceBlock->begin(),
378 block->getOperations(),
379 Block::iterator(firstInlinedInst),
380 ++Block::iterator(lastInlinedInst));
384 private:
385 // The block that originally contained the operations.
386 Block *sourceBlock;
388 // The first inlined operation.
389 Operation *firstInlinedInst;
391 // The last inlined operation.
392 Operation *lastInlinedInst;
395 /// Moving of a block. This rewrite is immediately reflected in the IR.
396 class MoveBlockRewrite : public BlockRewrite {
397 public:
398 MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
399 Region *region, Block *insertBeforeBlock)
400 : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
401 insertBeforeBlock(insertBeforeBlock) {}
403 static bool classof(const IRRewrite *rewrite) {
404 return rewrite->getKind() == Kind::MoveBlock;
407 void commit(RewriterBase &rewriter) override {
408 // The block was already moved. Just inform the listener.
409 if (auto *listener = rewriter.getListener()) {
410 // Note: `previousIt` cannot be passed because this is a delayed
411 // notification and iterators into past IR state cannot be represented.
412 listener->notifyBlockInserted(block, /*previous=*/region,
413 /*previousIt=*/{});
417 void rollback() override {
418 // Move the block back to its original position.
419 Region::iterator before =
420 insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
421 region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
424 private:
425 // The region in which this block was previously contained.
426 Region *region;
428 // The original successor of this block before it was moved. "nullptr" if
429 // this block was the only block in the region.
430 Block *insertBeforeBlock;
433 /// Block type conversion. This rewrite is partially reflected in the IR.
434 class BlockTypeConversionRewrite : public BlockRewrite {
435 public:
436 BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
437 Block *block, Block *origBlock)
438 : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
439 origBlock(origBlock) {}
441 static bool classof(const IRRewrite *rewrite) {
442 return rewrite->getKind() == Kind::BlockTypeConversion;
445 Block *getOrigBlock() const { return origBlock; }
447 void commit(RewriterBase &rewriter) override;
449 void rollback() override;
451 private:
452 /// The original block that was requested to have its signature converted.
453 Block *origBlock;
456 /// Replacing a block argument. This rewrite is not immediately reflected in the
457 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
458 /// until the rewrite is committed.
459 class ReplaceBlockArgRewrite : public BlockRewrite {
460 public:
461 ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
462 Block *block, BlockArgument arg,
463 const TypeConverter *converter)
464 : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
465 converter(converter) {}
467 static bool classof(const IRRewrite *rewrite) {
468 return rewrite->getKind() == Kind::ReplaceBlockArg;
471 void commit(RewriterBase &rewriter) override;
473 void rollback() override;
475 private:
476 BlockArgument arg;
478 /// The current type converter when the block argument was replaced.
479 const TypeConverter *converter;
482 /// An operation rewrite.
483 class OperationRewrite : public IRRewrite {
484 public:
485 /// Return the operation that this rewrite operates on.
486 Operation *getOperation() const { return op; }
488 static bool classof(const IRRewrite *rewrite) {
489 return rewrite->getKind() >= Kind::MoveOperation &&
490 rewrite->getKind() <= Kind::UnresolvedMaterialization;
493 protected:
494 OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
495 Operation *op)
496 : IRRewrite(kind, rewriterImpl), op(op) {}
498 // The operation that this rewrite operates on.
499 Operation *op;
502 /// Moving of an operation. This rewrite is immediately reflected in the IR.
503 class MoveOperationRewrite : public OperationRewrite {
504 public:
505 MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
506 Operation *op, Block *block, Operation *insertBeforeOp)
507 : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
508 insertBeforeOp(insertBeforeOp) {}
510 static bool classof(const IRRewrite *rewrite) {
511 return rewrite->getKind() == Kind::MoveOperation;
514 void commit(RewriterBase &rewriter) override {
515 // The operation was already moved. Just inform the listener.
516 if (auto *listener = rewriter.getListener()) {
517 // Note: `previousIt` cannot be passed because this is a delayed
518 // notification and iterators into past IR state cannot be represented.
519 listener->notifyOperationInserted(
520 op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
521 /*insertPt=*/{}));
525 void rollback() override {
526 // Move the operation back to its original position.
527 Block::iterator before =
528 insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
529 block->getOperations().splice(before, op->getBlock()->getOperations(), op);
532 private:
533 // The block in which this operation was previously contained.
534 Block *block;
536 // The original successor of this operation before it was moved. "nullptr"
537 // if this operation was the only operation in the region.
538 Operation *insertBeforeOp;
541 /// In-place modification of an op. This rewrite is immediately reflected in
542 /// the IR. The previous state of the operation is stored in this object.
543 class ModifyOperationRewrite : public OperationRewrite {
544 public:
545 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
546 Operation *op)
547 : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
548 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
549 operands(op->operand_begin(), op->operand_end()),
550 successors(op->successor_begin(), op->successor_end()) {
551 if (OpaqueProperties prop = op->getPropertiesStorage()) {
552 // Make a copy of the properties.
553 propertiesStorage = operator new(op->getPropertiesStorageSize());
554 OpaqueProperties propCopy(propertiesStorage);
555 name.initOpProperties(propCopy, /*init=*/prop);
559 static bool classof(const IRRewrite *rewrite) {
560 return rewrite->getKind() == Kind::ModifyOperation;
563 ~ModifyOperationRewrite() override {
564 assert(!propertiesStorage &&
565 "rewrite was neither committed nor rolled back");
568 void commit(RewriterBase &rewriter) override {
569 // Notify the listener that the operation was modified in-place.
570 if (auto *listener =
571 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
572 listener->notifyOperationModified(op);
574 if (propertiesStorage) {
575 OpaqueProperties propCopy(propertiesStorage);
576 // Note: The operation may have been erased in the mean time, so
577 // OperationName must be stored in this object.
578 name.destroyOpProperties(propCopy);
579 operator delete(propertiesStorage);
580 propertiesStorage = nullptr;
584 void rollback() override {
585 op->setLoc(loc);
586 op->setAttrs(attrs);
587 op->setOperands(operands);
588 for (const auto &it : llvm::enumerate(successors))
589 op->setSuccessor(it.value(), it.index());
590 if (propertiesStorage) {
591 OpaqueProperties propCopy(propertiesStorage);
592 op->copyProperties(propCopy);
593 name.destroyOpProperties(propCopy);
594 operator delete(propertiesStorage);
595 propertiesStorage = nullptr;
599 private:
600 OperationName name;
601 LocationAttr loc;
602 DictionaryAttr attrs;
603 SmallVector<Value, 8> operands;
604 SmallVector<Block *, 2> successors;
605 void *propertiesStorage = nullptr;
608 /// Replacing an operation. Erasing an operation is treated as a special case
609 /// with "null" replacements. This rewrite is not immediately reflected in the
610 /// IR. An internal IR mapping is updated, but values are not replaced and the
611 /// original op is not erased until the rewrite is committed.
612 class ReplaceOperationRewrite : public OperationRewrite {
613 public:
614 ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
615 Operation *op, const TypeConverter *converter)
616 : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
617 converter(converter) {}
619 static bool classof(const IRRewrite *rewrite) {
620 return rewrite->getKind() == Kind::ReplaceOperation;
623 void commit(RewriterBase &rewriter) override;
625 void rollback() override;
627 void cleanup(RewriterBase &rewriter) override;
629 private:
630 /// An optional type converter that can be used to materialize conversions
631 /// between the new and old values if necessary.
632 const TypeConverter *converter;
635 class CreateOperationRewrite : public OperationRewrite {
636 public:
637 CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
638 Operation *op)
639 : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
641 static bool classof(const IRRewrite *rewrite) {
642 return rewrite->getKind() == Kind::CreateOperation;
645 void commit(RewriterBase &rewriter) override {
646 // The operation was already created and inserted. Just inform the listener.
647 if (auto *listener = rewriter.getListener())
648 listener->notifyOperationInserted(op, /*previous=*/{});
651 void rollback() override;
/// Which kind of conversion an unresolved materialization stands for.
enum MaterializationKind {
  /// Converts a block argument of an illegal type back to its original type.
  Argument,

  /// Converts a value of an illegal type to a legal type.
  Target,

  /// Converts a value of a legal type back to an illegal type.
  Source
};
669 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
670 /// op. Unresolved materializations are erased at the end of the dialect
671 /// conversion.
672 class UnresolvedMaterializationRewrite : public OperationRewrite {
673 public:
674 UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
675 UnrealizedConversionCastOp op,
676 const TypeConverter *converter,
677 MaterializationKind kind, Type originalType);
679 static bool classof(const IRRewrite *rewrite) {
680 return rewrite->getKind() == Kind::UnresolvedMaterialization;
683 void rollback() override;
685 UnrealizedConversionCastOp getOperation() const {
686 return cast<UnrealizedConversionCastOp>(op);
689 /// Return the type converter of this materialization (which may be null).
690 const TypeConverter *getConverter() const {
691 return converterAndKind.getPointer();
694 /// Return the kind of this materialization.
695 MaterializationKind getMaterializationKind() const {
696 return converterAndKind.getInt();
699 /// Return the original type of the SSA value.
700 Type getOriginalType() const { return originalType; }
702 private:
703 /// The corresponding type converter to use when resolving this
704 /// materialization, and the kind of this materialization.
705 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
706 converterAndKind;
708 /// The original type of the SSA value. Only used for target
709 /// materializations.
710 Type originalType;
712 } // namespace
714 /// Return "true" if there is an operation rewrite that matches the specified
715 /// rewrite type and operation among the given rewrites.
716 template <typename RewriteTy, typename R>
717 static bool hasRewrite(R &&rewrites, Operation *op) {
718 return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
719 auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
720 return rewriteTy && rewriteTy->getOperation() == op;
724 //===----------------------------------------------------------------------===//
725 // ConversionPatternRewriterImpl
726 //===----------------------------------------------------------------------===//
727 namespace mlir {
728 namespace detail {
729 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
730 explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
731 const ConversionConfig &config)
732 : context(ctx), eraseRewriter(ctx), config(config) {}
734 //===--------------------------------------------------------------------===//
735 // State Management
736 //===--------------------------------------------------------------------===//
738 /// Return the current state of the rewriter.
739 RewriterState getCurrentState();
741 /// Apply all requested operation rewrites. This method is invoked when the
742 /// conversion process succeeds.
743 void applyRewrites();
745 /// Reset the state of the rewriter to a previously saved point.
746 void resetState(RewriterState state);
748 /// Append a rewrite. Rewrites are committed upon success and rolled back upon
749 /// failure.
750 template <typename RewriteTy, typename... Args>
751 void appendRewrite(Args &&...args) {
752 rewrites.push_back(
753 std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
756 /// Undo the rewrites (motions, splits) one by one in reverse order until
757 /// "numRewritesToKeep" rewrites remains.
758 void undoRewrites(unsigned numRewritesToKeep = 0);
760 /// Remap the given values to those with potentially different types. Returns
761 /// success if the values could be remapped, failure otherwise. `valueDiagTag`
762 /// is the tag used when describing a value within a diagnostic, e.g.
763 /// "operand".
764 LogicalResult remapValues(StringRef valueDiagTag,
765 std::optional<Location> inputLoc,
766 PatternRewriter &rewriter, ValueRange values,
767 SmallVectorImpl<Value> &remapped);
769 /// Return "true" if the given operation is ignored, and does not need to be
770 /// converted.
771 bool isOpIgnored(Operation *op) const;
773 /// Return "true" if the given operation was replaced or erased.
774 bool wasOpReplaced(Operation *op) const;
776 //===--------------------------------------------------------------------===//
777 // Type Conversion
778 //===--------------------------------------------------------------------===//
780 /// Convert the types of block arguments within the given region.
781 FailureOr<Block *>
782 convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
783 const TypeConverter &converter,
784 TypeConverter::SignatureConversion *entryConversion);
786 /// Apply the given signature conversion on the given block. The new block
787 /// containing the updated signature is returned. If no conversions were
788 /// necessary, e.g. if the block has no arguments, `block` is returned.
789 /// `converter` is used to generate any necessary cast operations that
790 /// translate between the origin argument types and those specified in the
791 /// signature conversion.
792 Block *applySignatureConversion(
793 ConversionPatternRewriter &rewriter, Block *block,
794 const TypeConverter *converter,
795 TypeConverter::SignatureConversion &signatureConversion);
797 //===--------------------------------------------------------------------===//
798 // Materializations
799 //===--------------------------------------------------------------------===//
801 /// Build an unresolved materialization operation given an output type and set
802 /// of input operands.
803 Value buildUnresolvedMaterialization(MaterializationKind kind,
804 OpBuilder::InsertPoint ip, Location loc,
805 ValueRange inputs, Type outputType,
806 Type originalType,
807 const TypeConverter *converter);
809 /// Build an N:1 materialization for the given original value that was
810 /// replaced with the given replacement values.
812 /// This is a workaround around incomplete 1:N support in the dialect
813 /// conversion driver. The conversion mapping can store only 1:1 replacements
814 /// and the conversion patterns only support single Value replacements in the
815 /// adaptor, so N values must be converted back to a single value. This
816 /// function will be deleted when full 1:N support has been added.
818 /// This function inserts an argument materialization back to the original
819 /// type, followed by a target materialization to the legalized type (if
820 /// applicable).
821 void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
822 ValueRange replacements, Value originalValue,
823 const TypeConverter *converter);
825 /// Find a replacement value for the given SSA value in the conversion value
826 /// mapping. The replacement value must have the same type as the given SSA
827 /// value. If there is no replacement value with the correct type, find the
828 /// latest replacement value (regardless of the type) and build a source
829 /// materialization.
830 Value findOrBuildReplacementValue(Value value,
831 const TypeConverter *converter);
833 //===--------------------------------------------------------------------===//
834 // Rewriter Notification Hooks
835 //===--------------------------------------------------------------------===//
837 //// Notifies that an op was inserted.
838 void notifyOperationInserted(Operation *op,
839 OpBuilder::InsertPoint previous) override;
841 /// Notifies that an op is about to be replaced with the given values.
842 void notifyOpReplaced(Operation *op, ArrayRef<ReplacementValues> newValues);
844 /// Notifies that a block is about to be erased.
845 void notifyBlockIsBeingErased(Block *block);
847 /// Notifies that a block was inserted.
848 void notifyBlockInserted(Block *block, Region *previous,
849 Region::iterator previousIt) override;
851 /// Notifies that a block is being inlined into another block.
852 void notifyBlockBeingInlined(Block *block, Block *srcBlock,
853 Block::iterator before);
855 /// Notifies that a pattern match failed for the given reason.
856 void
857 notifyMatchFailure(Location loc,
858 function_ref<void(Diagnostic &)> reasonCallback) override;
860 //===--------------------------------------------------------------------===//
861 // IR Erasure
862 //===--------------------------------------------------------------------===//
864 /// A rewriter that keeps track of erased ops and blocks. It ensures that no
865 /// operation or block is erased multiple times. This rewriter assumes that
866 /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
867 struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
868 public:
869 SingleEraseRewriter(MLIRContext *context)
870 : RewriterBase(context, /*listener=*/this) {}
872 /// Erase the given op (unless it was already erased).
873 void eraseOp(Operation *op) override {
874 if (wasErased(op))
875 return;
876 op->dropAllUses();
877 RewriterBase::eraseOp(op);
880 /// Erase the given block (unless it was already erased).
881 void eraseBlock(Block *block) override {
882 if (wasErased(block))
883 return;
884 assert(block->empty() && "expected empty block");
885 block->dropAllDefinedValueUses();
886 RewriterBase::eraseBlock(block);
889 bool wasErased(void *ptr) const { return erased.contains(ptr); }
891 void notifyOperationErased(Operation *op) override { erased.insert(op); }
893 void notifyBlockErased(Block *block) override { erased.insert(block); }
895 private:
896 /// Pointers to all erased operations and blocks.
897 DenseSet<void *> erased;
//===--------------------------------------------------------------------===//
// State
//===--------------------------------------------------------------------===//

/// MLIR context.
MLIRContext *context;

/// A rewriter that keeps track of ops/blocks that were already erased and
/// skips duplicate op/block erasures. This rewriter is used during the
/// "cleanup" phase of `applyRewrites`.
SingleEraseRewriter eraseRewriter;

/// Mapping from replaced values to their replacements. A value may be mapped
/// to a replacement of a different type.
ConversionValueMapping mapping;

/// Ordered list of all recorded IR rewrites (op/block creations, moves,
/// replacements, modifications, materializations, ...). `applyRewrites`
/// commits them in order; `undoRewrites` rolls them back in reverse order.
SmallVector<std::unique_ptr<IRRewrite>> rewrites;

/// A set of operations that should no longer be considered for legalization.
/// E.g., ops that are recursively legal. Ops that were replaced/erased are
/// tracked separately.
SetVector<Operation *> ignoredOps;

/// A set of operations that were replaced/erased. Such ops are not erased
/// immediately but only when the dialect conversion succeeds. In the meantime,
/// they should no longer be considered for legalization and any attempt
/// to modify/access them is invalid rewriter API usage.
SetVector<Operation *> replacedOps;

/// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
/// to the corresponding rewrite objects.
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
    unresolvedMaterializations;

/// The current type converter, or nullptr if no type converter is currently
/// active.
const TypeConverter *currentTypeConverter = nullptr;

/// A mapping of regions to type converters that should be used when
/// converting the arguments of blocks within that region.
DenseMap<Region *, const TypeConverter *> regionToConverter;

/// Dialect conversion configuration.
const ConversionConfig &config;

#ifndef NDEBUG
/// A set of operations that have pending updates. This tracking isn't
/// strictly necessary, and is thus only active during debug builds for extra
/// verification.
SmallPtrSet<Operation *, 1> pendingRootUpdates;

/// A logger used to emit diagnostics during the conversion process.
llvm::ScopedPrinter logger{llvm::dbgs()};
#endif
956 } // namespace detail
957 } // namespace mlir
959 const ConversionConfig &IRRewrite::getConfig() const {
960 return rewriterImpl.config;
963 void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
964 // Inform the listener about all IR modifications that have already taken
965 // place: References to the original block have been replaced with the new
966 // block.
967 if (auto *listener =
968 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
969 for (Operation *op : block->getUsers())
970 listener->notifyOperationModified(op);
973 void BlockTypeConversionRewrite::rollback() {
974 block->replaceAllUsesWith(origBlock);
977 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
978 Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
979 if (!repl)
980 return;
982 if (isa<BlockArgument>(repl)) {
983 rewriter.replaceAllUsesWith(arg, repl);
984 return;
987 // If the replacement value is an operation, we check to make sure that we
988 // don't replace uses that are within the parent operation of the
989 // replacement value.
990 Operation *replOp = cast<OpResult>(repl).getOwner();
991 Block *replBlock = replOp->getBlock();
992 rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
993 Operation *user = operand.getOwner();
994 return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
998 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
1000 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1001 auto *listener =
1002 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1004 // Compute replacement values.
1005 SmallVector<Value> replacements =
1006 llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1007 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1010 // Notify the listener that the operation is about to be replaced.
1011 if (listener)
1012 listener->notifyOperationReplaced(op, replacements);
1014 // Replace all uses with the new values.
1015 for (auto [result, newValue] :
1016 llvm::zip_equal(op->getResults(), replacements))
1017 if (newValue)
1018 rewriter.replaceAllUsesWith(result, newValue);
1020 // The original op will be erased, so remove it from the set of unlegalized
1021 // ops.
1022 if (getConfig().unlegalizedOps)
1023 getConfig().unlegalizedOps->erase(op);
1025 // Notify the listener that the operation (and its nested operations) was
1026 // erased.
1027 if (listener) {
1028 op->walk<WalkOrder::PostOrder>(
1029 [&](Operation *op) { listener->notifyOperationErased(op); });
1032 // Do not erase the operation yet. It may still be referenced in `mapping`.
1033 // Just unlink it for now and erase it during cleanup.
1034 op->getBlock()->getOperations().remove(op);
1037 void ReplaceOperationRewrite::rollback() {
1038 for (auto result : op->getResults())
1039 rewriterImpl.mapping.erase(result);
1042 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1043 rewriter.eraseOp(op);
1046 void CreateOperationRewrite::rollback() {
1047 for (Region &region : op->getRegions()) {
1048 while (!region.getBlocks().empty())
1049 region.getBlocks().remove(region.getBlocks().begin());
1051 op->dropAllUses();
1052 op->erase();
/// Records the creation of the unresolved materialization `op` and registers
/// this rewrite in `rewriterImpl.unresolvedMaterializations`, so that the cast
/// op can later be recognized as a driver-created materialization.
/// `originalType` may only be set for target materializations.
UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
    ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
    const TypeConverter *converter, MaterializationKind kind, Type originalType)
    : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
      converterAndKind(converter, kind), originalType(originalType) {
  assert((!originalType || kind == MaterializationKind::Target) &&
         "original type is valid only for target materializations");
  rewriterImpl.unresolvedMaterializations[op] = this;
}
1065 void UnresolvedMaterializationRewrite::rollback() {
1066 if (getMaterializationKind() == MaterializationKind::Target) {
1067 for (Value input : op->getOperands())
1068 rewriterImpl.mapping.erase(input);
1070 rewriterImpl.unresolvedMaterializations.erase(getOperation());
1071 op->erase();
1074 void ConversionPatternRewriterImpl::applyRewrites() {
1075 // Commit all rewrites.
1076 IRRewriter rewriter(context, config.listener);
1077 // Note: New rewrites may be added during the "commit" phase and the
1078 // `rewrites` vector may reallocate.
1079 for (size_t i = 0; i < rewrites.size(); ++i)
1080 rewrites[i]->commit(rewriter);
1082 // Clean up all rewrites.
1083 for (auto &rewrite : rewrites)
1084 rewrite->cleanup(eraseRewriter);
1087 //===----------------------------------------------------------------------===//
1088 // State Management
1090 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
1091 return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1094 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
1095 // Undo any rewrites.
1096 undoRewrites(state.numRewrites);
1098 // Pop all of the recorded ignored operations that are no longer valid.
1099 while (ignoredOps.size() != state.numIgnoredOperations)
1100 ignoredOps.pop_back();
1102 while (replacedOps.size() != state.numReplacedOps)
1103 replacedOps.pop_back();
1106 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
1107 for (auto &rewrite :
1108 llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1109 rewrite->rollback();
1110 rewrites.resize(numRewritesToKeep);
1113 LogicalResult ConversionPatternRewriterImpl::remapValues(
1114 StringRef valueDiagTag, std::optional<Location> inputLoc,
1115 PatternRewriter &rewriter, ValueRange values,
1116 SmallVectorImpl<Value> &remapped) {
1117 remapped.reserve(llvm::size(values));
1119 for (const auto &it : llvm::enumerate(values)) {
1120 Value operand = it.value();
1121 Type origType = operand.getType();
1122 Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1124 if (!currentTypeConverter) {
1125 // The current pattern does not have a type converter. I.e., it does not
1126 // distinguish between legal and illegal types. For each operand, simply
1127 // pass through the most recently mapped value.
1128 remapped.push_back(mapping.lookupOrDefault(operand));
1129 continue;
1132 // If there is no legal conversion, fail to match this pattern.
1133 SmallVector<Type, 1> legalTypes;
1134 if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
1135 notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1136 diag << "unable to convert type for " << valueDiagTag << " #"
1137 << it.index() << ", type was " << origType;
1139 return failure();
1142 if (legalTypes.size() != 1) {
1143 // TODO: Parts of the dialect conversion infrastructure do not support
1144 // 1->N type conversions yet. Therefore, if a type is converted to 0 or
1145 // multiple types, the only thing that we can do for now is passing
1146 // through the most recently mapped value. Fixing this requires
1147 // improvements to the `ConversionValueMapping` (to be able to store 1:N
1148 // mappings) and to the `ConversionPattern` adaptor handling (to be able
1149 // to pass multiple remapped values for a single operand to the adaptor).
1150 remapped.push_back(mapping.lookupOrDefault(operand));
1151 continue;
1154 // Handle 1->1 type conversions.
1155 Type desiredType = legalTypes.front();
1156 // Try to find a mapped value with the desired type. (Or the operand itself
1157 // if the value is not mapped at all.)
1158 Value newOperand = mapping.lookupOrDefault(operand, desiredType);
1159 if (newOperand.getType() != desiredType) {
1160 // If the looked up value's type does not have the desired type, it means
1161 // that the value was replaced with a value of different type and no
1162 // source materialization was created yet.
1163 Value castValue = buildUnresolvedMaterialization(
1164 MaterializationKind::Target, computeInsertPoint(newOperand),
1165 operandLoc,
1166 /*inputs=*/newOperand, /*outputType=*/desiredType,
1167 /*originalType=*/origType, currentTypeConverter);
1168 mapping.map(newOperand, castValue);
1169 newOperand = castValue;
1171 remapped.push_back(newOperand);
1173 return success();
1176 bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
1177 // Check to see if this operation is ignored or was replaced.
1178 return replacedOps.count(op) || ignoredOps.count(op);
1181 bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
1182 // Check to see if this operation was replaced.
1183 return replacedOps.count(op);
1186 //===----------------------------------------------------------------------===//
1187 // Type Conversion
1189 FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
1190 ConversionPatternRewriter &rewriter, Region *region,
1191 const TypeConverter &converter,
1192 TypeConverter::SignatureConversion *entryConversion) {
1193 regionToConverter[region] = &converter;
1194 if (region->empty())
1195 return nullptr;
1197 // Convert the arguments of each non-entry block within the region.
1198 for (Block &block :
1199 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1200 // Compute the signature for the block with the provided converter.
1201 std::optional<TypeConverter::SignatureConversion> conversion =
1202 converter.convertBlockSignature(&block);
1203 if (!conversion)
1204 return failure();
1205 // Convert the block with the computed signature.
1206 applySignatureConversion(rewriter, &block, &converter, *conversion);
1209 // Convert the entry block. If an entry signature conversion was provided,
1210 // use that one. Otherwise, compute the signature with the type converter.
1211 if (entryConversion)
1212 return applySignatureConversion(rewriter, &region->front(), &converter,
1213 *entryConversion);
1214 std::optional<TypeConverter::SignatureConversion> conversion =
1215 converter.convertBlockSignature(&region->front());
1216 if (!conversion)
1217 return failure();
1218 return applySignatureConversion(rewriter, &region->front(), &converter,
1219 *conversion);
/// Converts the signature of `block` according to `signatureConversion`.
///
/// If the converted argument types equal the current ones, `block` is
/// returned unchanged. Otherwise:
/// - a new block with the converted argument types is inserted right after
///   `block`, and all ops are moved into it;
/// - all uses of `block` are redirected to the new block;
/// - every original argument is mapped to a replacement: a provided
///   replacement value, an argument materialization of the new block
///   arguments, or a source materialization "out of thin air" for a dropped
///   argument;
/// - `block` is erased (unlinked now, destroyed when rewrites are applied).
///
/// Returns the new block.
Block *ConversionPatternRewriterImpl::applySignatureConversion(
    ConversionPatternRewriter &rewriter, Block *block,
    const TypeConverter *converter,
    TypeConverter::SignatureConversion &signatureConversion) {
  // Restore the caller's insertion point on exit (`createBlock` below moves
  // it into the new block).
  OpBuilder::InsertionGuard g(rewriter);

  // If no arguments are being changed or added, there is nothing to do.
  unsigned origArgCount = block->getNumArguments();
  auto convertedTypes = signatureConversion.getConvertedTypes();
  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
    return block;

  // Compute the locations of all block arguments in the new block. New
  // arguments that stem from an original argument inherit its location; all
  // others get an unknown location.
  SmallVector<Location> newLocs(convertedTypes.size(),
                                rewriter.getUnknownLoc());
  for (unsigned i = 0; i < origArgCount; ++i) {
    auto inputMap = signatureConversion.getInputMapping(i);
    if (!inputMap || inputMap->replacementValue)
      continue;
    Location origLoc = block->getArgument(i).getLoc();
    for (unsigned j = 0; j < inputMap->size; ++j)
      newLocs[inputMap->inputNo + j] = origLoc;
  }

  // Insert a new block with the converted block argument types and move all ops
  // from the old block to the new block.
  Block *newBlock =
      rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
                           convertedTypes, newLocs);

  // If a listener is attached to the dialect conversion, ops cannot be moved
  // to the destination block in bulk ("fast path"). This is because at the time
  // the notifications are sent, it is unknown which ops were moved. Instead,
  // ops should be moved one-by-one ("slow path"), so that a separate
  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
  // a bit more efficient, so we try to do that when possible.
  bool fastPath = !config.listener;
  if (fastPath) {
    appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
    newBlock->getOperations().splice(newBlock->end(), block->getOperations());
  } else {
    while (!block->empty())
      rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
  }

  // Replace all uses of the old block with the new block.
  block->replaceAllUsesWith(newBlock);

  // Map every original argument to its replacement and record a
  // `ReplaceBlockArgRewrite` for it.
  for (unsigned i = 0; i != origArgCount; ++i) {
    BlockArgument origArg = block->getArgument(i);
    Type origArgType = origArg.getType();

    std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
        signatureConversion.getInputMapping(i);
    if (!inputMap) {
      // This block argument was dropped and no replacement value was provided.
      // Materialize a replacement value "out of thin air".
      Value repl = buildUnresolvedMaterialization(
          MaterializationKind::Source,
          OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
          /*inputs=*/ValueRange(),
          /*outputType=*/origArgType, /*originalType=*/Type(), converter);
      mapping.map(origArg, repl);
      appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
      continue;
    }

    if (Value repl = inputMap->replacementValue) {
      // This block argument was dropped and a replacement value was provided.
      assert(inputMap->size == 0 &&
             "invalid to provide a replacement value when the argument isn't "
             "dropped");
      mapping.map(origArg, repl);
      appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
      continue;
    }

    // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
    // dialect conversion. Therefore, we need an argument materialization to
    // turn the replacement block arguments into a single SSA value that can be
    // used as a replacement.
    auto replArgs =
        newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
    insertNTo1Materialization(
        OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
        /*replacements=*/replArgs, /*originalValue=*/origArg, converter);
    appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
  }

  appendRewrite<BlockTypeConversionRewrite>(newBlock, block);

  // Erase the old block. (It is just unlinked for now and will be erased during
  // cleanup.)
  rewriter.eraseBlock(block);

  return newBlock;
}
1320 //===----------------------------------------------------------------------===//
1321 // Materializations
1322 //===----------------------------------------------------------------------===//
1324 /// Build an unresolved materialization operation given an output type and set
1325 /// of input operands.
1326 Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1327 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1328 ValueRange inputs, Type outputType, Type originalType,
1329 const TypeConverter *converter) {
1330 assert((!originalType || kind == MaterializationKind::Target) &&
1331 "original type is valid only for target materializations");
1333 // Avoid materializing an unnecessary cast.
1334 if (inputs.size() == 1 && inputs.front().getType() == outputType)
1335 return inputs.front();
1337 // Create an unresolved materialization. We use a new OpBuilder to avoid
1338 // tracking the materialization like we do for other operations.
1339 OpBuilder builder(outputType.getContext());
1340 builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1341 auto convertOp =
1342 builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1343 appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
1344 originalType);
1345 return convertOp.getResult(0);
1348 void ConversionPatternRewriterImpl::insertNTo1Materialization(
1349 OpBuilder::InsertPoint ip, Location loc, ValueRange replacements,
1350 Value originalValue, const TypeConverter *converter) {
1351 // Insert argument materialization back to the original type.
1352 Type originalType = originalValue.getType();
1353 Value argMat =
1354 buildUnresolvedMaterialization(MaterializationKind::Argument, ip, loc,
1355 /*inputs=*/replacements, originalType,
1356 /*originalType=*/Type(), converter);
1357 mapping.map(originalValue, argMat);
1359 // Insert target materialization to the legalized type.
1360 Type legalOutputType;
1361 if (converter) {
1362 legalOutputType = converter->convertType(originalType);
1363 } else if (replacements.size() == 1) {
1364 // When there is no type converter, assume that the replacement value
1365 // types are legal. This is reasonable to assume because they were
1366 // specified by the user.
1367 // FIXME: This won't work for 1->N conversions because multiple output
1368 // types are not supported in parts of the dialect conversion. In such a
1369 // case, we currently use the original value type.
1370 legalOutputType = replacements[0].getType();
1372 if (legalOutputType && legalOutputType != originalType) {
1373 Value targetMat = buildUnresolvedMaterialization(
1374 MaterializationKind::Target, computeInsertPoint(argMat), loc,
1375 /*inputs=*/argMat, /*outputType=*/legalOutputType,
1376 /*originalType=*/originalType, converter);
1377 mapping.map(argMat, targetMat);
1381 Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
1382 Value value, const TypeConverter *converter) {
1383 // Find a replacement value with the same type.
1384 Value repl = mapping.lookupOrNull(value, value.getType());
1385 if (repl)
1386 return repl;
1388 // Check if the value is dead. No replacement value is needed in that case.
1389 // This is an approximate check that may have false negatives but does not
1390 // require computing and traversing an inverse mapping. (We may end up
1391 // building source materializations that are never used and that fold away.)
1392 if (llvm::all_of(value.getUsers(),
1393 [&](Operation *op) { return replacedOps.contains(op); }) &&
1394 !mapping.isMappedTo(value))
1395 return Value();
1397 // No replacement value was found. Get the latest replacement value
1398 // (regardless of the type) and build a source materialization to the
1399 // original type.
1400 repl = mapping.lookupOrNull(value);
1401 if (!repl) {
1402 // No replacement value is registered in the mapping. This means that the
1403 // value is dropped and no longer needed. (If the value were still needed,
1404 // a source materialization producing a replacement value "out of thin air"
1405 // would have already been created during `replaceOp` or
1406 // `applySignatureConversion`.)
1407 return Value();
1409 Value castValue = buildUnresolvedMaterialization(
1410 MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
1411 /*inputs=*/repl, /*outputType=*/value.getType(),
1412 /*originalType=*/Type(), converter);
1413 return castValue;
1416 //===----------------------------------------------------------------------===//
1417 // Rewriter Notification Hooks
1419 void ConversionPatternRewriterImpl::notifyOperationInserted(
1420 Operation *op, OpBuilder::InsertPoint previous) {
1421 LLVM_DEBUG({
1422 logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
1423 << ")\n";
1425 assert(!wasOpReplaced(op->getParentOp()) &&
1426 "attempting to insert into a block within a replaced/erased op");
1428 if (!previous.isSet()) {
1429 // This is a newly created op.
1430 appendRewrite<CreateOperationRewrite>(op);
1431 return;
1433 Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
1434 ? nullptr
1435 : &*previous.getPoint();
1436 appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
/// Records the replacement of `op` with `newValues` (one entry per result; an
/// empty entry means the result was dropped). Maps every result to its
/// replacement in `mapping`, records a `ReplaceOperationRewrite`, and marks
/// `op` and all nested ops as replaced. The IR itself is only changed when
/// the rewrites are committed.
void ConversionPatternRewriterImpl::notifyOpReplaced(
    Operation *op, ArrayRef<ReplacementValues> newValues) {
  assert(newValues.size() == op->getNumResults());
  // NOTE(review): this checks `ignoredOps`, while the message talks about
  // replaced ops (`replacedOps`). Confirm which set is intended.
  assert(!ignoredOps.contains(op) && "operation was already replaced");

  // Check if replaced op is an unresolved materialization, i.e., an
  // unrealized_conversion_cast op that was created by the conversion driver.
  bool isUnresolvedMaterialization = false;
  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
    if (unresolvedMaterializations.contains(castOp))
      isUnresolvedMaterialization = true;

  // Create mappings for each of the new result values.
  for (auto [n, result] : llvm::zip_equal(newValues, op->getResults())) {
    ReplacementValues repl = n;
    if (repl.empty()) {
      // This result was dropped and no replacement value was provided.
      if (isUnresolvedMaterialization) {
        // Do not create another materialization if we are erasing a
        // materialization.
        continue;
      }

      // Materialize a replacement value "out of thin air".
      Value sourceMat = buildUnresolvedMaterialization(
          MaterializationKind::Source, computeInsertPoint(result),
          result.getLoc(), /*inputs=*/ValueRange(),
          /*outputType=*/result.getType(), /*originalType=*/Type(),
          currentTypeConverter);
      repl.push_back(sourceMat);
    } else {
      // Make sure that the user does not mess with unresolved materializations
      // that were inserted by the conversion driver. We keep track of these
      // ops in internal data structures. Erasing them must be allowed because
      // this can happen when the user is erasing an entire block (including
      // its body). But replacing them with another value should be forbidden
      // to avoid problems with the `mapping`.
      assert(!isUnresolvedMaterialization &&
             "attempting to replace an unresolved materialization");
    }

    // Remap result to replacement value. (Defensive: both branches above
    // either `continue` or leave `repl` non-empty.)
    if (repl.empty())
      continue;

    if (repl.size() == 1) {
      // Single replacement value: replace directly.
      mapping.map(result, repl.front());
    } else {
      // Multiple replacement values: insert N:1 materialization.
      insertNTo1Materialization(computeInsertPoint(result), result.getLoc(),
                                /*replacements=*/repl, /*originalValue=*/result,
                                currentTypeConverter);
    }
  }

  appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
  // Mark this operation and all nested ops as replaced.
  op->walk([&](Operation *op) { replacedOps.insert(op); });
}
1500 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
1501 appendRewrite<EraseBlockRewrite>(block);
1504 void ConversionPatternRewriterImpl::notifyBlockInserted(
1505 Block *block, Region *previous, Region::iterator previousIt) {
1506 assert(!wasOpReplaced(block->getParentOp()) &&
1507 "attempting to insert into a region within a replaced/erased op");
1508 LLVM_DEBUG(
1510 Operation *parent = block->getParentOp();
1511 if (parent) {
1512 logger.startLine() << "** Insert Block into : '" << parent->getName()
1513 << "'(" << parent << ")\n";
1514 } else {
1515 logger.startLine()
1516 << "** Insert Block into detached Region (nullptr parent op)'";
1520 if (!previous) {
1521 // This is a newly created block.
1522 appendRewrite<CreateBlockRewrite>(block);
1523 return;
1525 Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
1526 appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1529 void ConversionPatternRewriterImpl::notifyBlockBeingInlined(
1530 Block *block, Block *srcBlock, Block::iterator before) {
1531 appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
1534 void ConversionPatternRewriterImpl::notifyMatchFailure(
1535 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1536 LLVM_DEBUG({
1537 Diagnostic diag(loc, DiagnosticSeverity::Remark);
1538 reasonCallback(diag);
1539 logger.startLine() << "** Failure : " << diag.str() << "\n";
1540 if (config.notifyCallback)
1541 config.notifyCallback(diag);
1545 //===----------------------------------------------------------------------===//
1546 // ConversionPatternRewriter
1547 //===----------------------------------------------------------------------===//
/// Creates a conversion rewriter backed by a fresh
/// `ConversionPatternRewriterImpl`. The impl is installed as this rewriter's
/// listener, so op/block insertions are reported to it (and recorded as
/// rewrites).
ConversionPatternRewriter::ConversionPatternRewriter(
    MLIRContext *ctx, const ConversionConfig &config)
    : PatternRewriter(ctx),
      impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
  setListener(impl.get());
}
// Defaulted out of line. NOTE(review): presumably so that the
// `ConversionPatternRewriterImpl` owned by `impl` is a complete type here.
ConversionPatternRewriter::~ConversionPatternRewriter() = default;
1558 void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
1559 assert(op && newOp && "expected non-null op");
1560 replaceOp(op, newOp->getResults());
1563 void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
1564 assert(op->getNumResults() == newValues.size() &&
1565 "incorrect # of replacement values");
1566 LLVM_DEBUG({
1567 impl->logger.startLine()
1568 << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1570 SmallVector<ReplacementValues> newVals(newValues.size());
1571 for (auto [index, val] : llvm::enumerate(newValues))
1572 if (val)
1573 newVals[index].push_back(val);
1574 impl->notifyOpReplaced(op, newVals);
1577 void ConversionPatternRewriter::replaceOpWithMultiple(
1578 Operation *op, ArrayRef<ValueRange> newValues) {
1579 assert(op->getNumResults() == newValues.size() &&
1580 "incorrect # of replacement values");
1581 LLVM_DEBUG({
1582 impl->logger.startLine()
1583 << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1585 SmallVector<ReplacementValues> newVals(newValues.size(), {});
1586 for (auto [index, val] : llvm::enumerate(newValues))
1587 llvm::append_range(newVals[index], val);
1588 impl->notifyOpReplaced(op, newVals);
1591 void ConversionPatternRewriter::eraseOp(Operation *op) {
1592 LLVM_DEBUG({
1593 impl->logger.startLine()
1594 << "** Erase : '" << op->getName() << "'(" << op << ")\n";
1596 SmallVector<ReplacementValues> nullRepls(op->getNumResults(), {});
1597 impl->notifyOpReplaced(op, nullRepls);
1600 void ConversionPatternRewriter::eraseBlock(Block *block) {
1601 assert(!impl->wasOpReplaced(block->getParentOp()) &&
1602 "attempting to erase a block within a replaced/erased op");
1604 // Mark all ops for erasure.
1605 for (Operation &op : *block)
1606 eraseOp(&op);
1608 // Unlink the block from its parent region. The block is kept in the rewrite
1609 // object and will be actually destroyed when rewrites are applied. This
1610 // allows us to keep the operations in the block live and undo the removal by
1611 // re-inserting the block.
1612 impl->notifyBlockIsBeingErased(block);
1613 block->getParent()->getBlocks().remove(block);
1616 Block *ConversionPatternRewriter::applySignatureConversion(
1617 Block *block, TypeConverter::SignatureConversion &conversion,
1618 const TypeConverter *converter) {
1619 assert(!impl->wasOpReplaced(block->getParentOp()) &&
1620 "attempting to apply a signature conversion to a block within a "
1621 "replaced/erased op");
1622 return impl->applySignatureConversion(*this, block, converter, conversion);
1625 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
1626 Region *region, const TypeConverter &converter,
1627 TypeConverter::SignatureConversion *entryConversion) {
1628 assert(!impl->wasOpReplaced(region->getParentOp()) &&
1629 "attempting to apply a signature conversion to a block within a "
1630 "replaced/erased op");
1631 return impl->convertRegionTypes(*this, region, converter, entryConversion);
1634 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
1635 Value to) {
1636 LLVM_DEBUG({
1637 Operation *parentOp = from.getOwner()->getParentOp();
1638 impl->logger.startLine() << "** Replace Argument : '" << from
1639 << "'(in region of '" << parentOp->getName()
1640 << "'(" << from.getOwner()->getParentOp() << ")\n";
1642 impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
1643 impl->currentTypeConverter);
1644 impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1647 Value ConversionPatternRewriter::getRemappedValue(Value key) {
1648 SmallVector<Value> remappedValues;
1649 if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
1650 remappedValues)))
1651 return nullptr;
1652 return remappedValues.front();
1655 LogicalResult
1656 ConversionPatternRewriter::getRemappedValues(ValueRange keys,
1657 SmallVectorImpl<Value> &results) {
1658 if (keys.empty())
1659 return success();
1660 return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
1661 results);
1664 void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
1665 Block::iterator before,
1666 ValueRange argValues) {
1667 #ifndef NDEBUG
1668 assert(argValues.size() == source->getNumArguments() &&
1669 "incorrect # of argument replacement values");
1670 assert(!impl->wasOpReplaced(source->getParentOp()) &&
1671 "attempting to inline a block from a replaced/erased op");
1672 assert(!impl->wasOpReplaced(dest->getParentOp()) &&
1673 "attempting to inline a block into a replaced/erased op");
1674 auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
1675 // The source block will be deleted, so it should not have any users (i.e.,
1676 // there should be no predecessors).
1677 assert(llvm::all_of(source->getUsers(), opIgnored) &&
1678 "expected 'source' to have no predecessors");
1679 #endif // NDEBUG
1681 // If a listener is attached to the dialect conversion, ops cannot be moved
1682 // to the destination block in bulk ("fast path"). This is because at the time
1683 // the notifications are sent, it is unknown which ops were moved. Instead,
1684 // ops should be moved one-by-one ("slow path"), so that a separate
1685 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1686 // a bit more efficient, so we try to do that when possible.
1687 bool fastPath = !impl->config.listener;
1689 if (fastPath)
1690 impl->notifyBlockBeingInlined(dest, source, before);
1692 // Replace all uses of block arguments.
1693 for (auto it : llvm::zip(source->getArguments(), argValues))
1694 replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1696 if (fastPath) {
1697 // Move all ops at once.
1698 dest->getOperations().splice(before, source->getOperations());
1699 } else {
1700 // Move op by op.
1701 while (!source->empty())
1702 moveOpBefore(&source->front(), dest, before);
1705 // Erase the source block.
1706 eraseBlock(source);
1709 void ConversionPatternRewriter::startOpModification(Operation *op) {
1710 assert(!impl->wasOpReplaced(op) &&
1711 "attempting to modify a replaced/erased op");
1712 #ifndef NDEBUG
1713 impl->pendingRootUpdates.insert(op);
1714 #endif
1715 impl->appendRewrite<ModifyOperationRewrite>(op);
1718 void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
1719 assert(!impl->wasOpReplaced(op) &&
1720 "attempting to modify a replaced/erased op");
1721 PatternRewriter::finalizeOpModification(op);
1722 // There is nothing to do here, we only need to track the operation at the
1723 // start of the update.
1724 #ifndef NDEBUG
1725 assert(impl->pendingRootUpdates.erase(op) &&
1726 "operation did not have a pending in-place update");
1727 #endif
1730 void ConversionPatternRewriter::cancelOpModification(Operation *op) {
1731 #ifndef NDEBUG
1732 assert(impl->pendingRootUpdates.erase(op) &&
1733 "operation did not have a pending in-place update");
1734 #endif
1735 // Erase the last update for this operation.
1736 auto it = llvm::find_if(
1737 llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
1738 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1739 return modifyRewrite && modifyRewrite->getOperation() == op;
1741 assert(it != impl->rewrites.rend() && "no root update started on op");
1742 (*it)->rollback();
1743 int updateIdx = std::prev(impl->rewrites.rend()) - it;
1744 impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
1747 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
1748 return *impl;
1751 //===----------------------------------------------------------------------===//
1752 // ConversionPattern
1753 //===----------------------------------------------------------------------===//
1755 LogicalResult
1756 ConversionPattern::matchAndRewrite(Operation *op,
1757 PatternRewriter &rewriter) const {
1758 auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1759 auto &rewriterImpl = dialectRewriter.getImpl();
1761 // Track the current conversion pattern type converter in the rewriter.
1762 llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1763 getTypeConverter());
1765 // Remap the operands of the operation.
1766 SmallVector<Value, 4> operands;
1767 if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1768 op->getOperands(), operands))) {
1769 return failure();
1771 return matchAndRewrite(op, operands, dialectRewriter);
1774 //===----------------------------------------------------------------------===//
1775 // OperationLegalizer
1776 //===----------------------------------------------------------------------===//
1778 namespace {
1779 /// A set of rewrite patterns that can be used to legalize a given operation.
1780 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1782 /// This class defines a recursive operation legalizer.
1783 class OperationLegalizer {
1784 public:
1785 using LegalizationAction = ConversionTarget::LegalizationAction;
1787 OperationLegalizer(const ConversionTarget &targetInfo,
1788 const FrozenRewritePatternSet &patterns,
1789 const ConversionConfig &config);
1791 /// Returns true if the given operation is known to be illegal on the target.
1792 bool isIllegal(Operation *op) const;
1794 /// Attempt to legalize the given operation. Returns success if the operation
1795 /// was legalized, failure otherwise.
1796 LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1798 /// Returns the conversion target in use by the legalizer.
1799 const ConversionTarget &getTarget() { return target; }
1801 private:
1802 /// Attempt to legalize the given operation by folding it.
1803 LogicalResult legalizeWithFold(Operation *op,
1804 ConversionPatternRewriter &rewriter);
1806 /// Attempt to legalize the given operation by applying a pattern. Returns
1807 /// success if the operation was legalized, failure otherwise.
1808 LogicalResult legalizeWithPattern(Operation *op,
1809 ConversionPatternRewriter &rewriter);
1811 /// Return true if the given pattern may be applied to the given operation,
1812 /// false otherwise.
1813 bool canApplyPattern(Operation *op, const Pattern &pattern,
1814 ConversionPatternRewriter &rewriter);
1816 /// Legalize the resultant IR after successfully applying the given pattern.
1817 LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1818 ConversionPatternRewriter &rewriter,
1819 RewriterState &curState);
1821 /// Legalizes the actions registered during the execution of a pattern.
1822 LogicalResult
1823 legalizePatternBlockRewrites(Operation *op,
1824 ConversionPatternRewriter &rewriter,
1825 ConversionPatternRewriterImpl &impl,
1826 RewriterState &state, RewriterState &newState);
1827 LogicalResult legalizePatternCreatedOperations(
1828 ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1829 RewriterState &state, RewriterState &newState);
1830 LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1831 ConversionPatternRewriterImpl &impl,
1832 RewriterState &state,
1833 RewriterState &newState);
1835 //===--------------------------------------------------------------------===//
1836 // Cost Model
1837 //===--------------------------------------------------------------------===//
1839 /// Build an optimistic legalization graph given the provided patterns. This
1840 /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1841 /// patterns for operations that are not directly legal, but may be
1842 /// transitively legal for the current target given the provided patterns.
1843 void buildLegalizationGraph(
1844 LegalizationPatterns &anyOpLegalizerPatterns,
1845 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1847 /// Compute the benefit of each node within the computed legalization graph.
1848 /// This orders the patterns within 'legalizerPatterns' based upon two
1849 /// criteria:
1850 /// 1) Prefer patterns that have the lowest legalization depth, i.e.
1851 /// represent the more direct mapping to the target.
1852 /// 2) When comparing patterns with the same legalization depth, prefer the
1853 /// pattern with the highest PatternBenefit. This allows for users to
1854 /// prefer specific legalizations over others.
1855 void computeLegalizationGraphBenefit(
1856 LegalizationPatterns &anyOpLegalizerPatterns,
1857 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1859 /// Compute the legalization depth when legalizing an operation of the given
1860 /// type.
1861 unsigned computeOpLegalizationDepth(
1862 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1863 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1865 /// Apply the conversion cost model to the given set of patterns, and return
1866 /// the smallest legalization depth of any of the patterns. See
1867 /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1868 unsigned applyCostModelToPatterns(
1869 LegalizationPatterns &patterns,
1870 DenseMap<OperationName, unsigned> &minOpPatternDepth,
1871 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1873 /// The current set of patterns that have been applied.
1874 SmallPtrSet<const Pattern *, 8> appliedPatterns;
1876 /// The legalization information provided by the target.
1877 const ConversionTarget &target;
1879 /// The pattern applicator to use for conversions.
1880 PatternApplicator applicator;
1882 /// Dialect conversion configuration.
1883 const ConversionConfig &config;
1885 } // namespace
1887 OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
1888 const FrozenRewritePatternSet &patterns,
1889 const ConversionConfig &config)
1890 : target(targetInfo), applicator(patterns), config(config) {
1891 // The set of patterns that can be applied to illegal operations to transform
1892 // them into legal ones.
1893 DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
1894 LegalizationPatterns anyOpLegalizerPatterns;
1896 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1897 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1900 bool OperationLegalizer::isIllegal(Operation *op) const {
1901 return target.isIllegal(op);
1904 LogicalResult
1905 OperationLegalizer::legalize(Operation *op,
1906 ConversionPatternRewriter &rewriter) {
1907 #ifndef NDEBUG
1908 const char *logLineComment =
1909 "//===-------------------------------------------===//\n";
1911 auto &logger = rewriter.getImpl().logger;
1912 #endif
1913 LLVM_DEBUG({
1914 logger.getOStream() << "\n";
1915 logger.startLine() << logLineComment;
1916 logger.startLine() << "Legalizing operation : '" << op->getName() << "'("
1917 << op << ") {\n";
1918 logger.indent();
1920 // If the operation has no regions, just print it here.
1921 if (op->getNumRegions() == 0) {
1922 op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
1923 logger.getOStream() << "\n\n";
1927 // Check if this operation is legal on the target.
1928 if (auto legalityInfo = target.isLegal(op)) {
1929 LLVM_DEBUG({
1930 logSuccess(
1931 logger, "operation marked legal by the target{0}",
1932 legalityInfo->isRecursivelyLegal
1933 ? "; NOTE: operation is recursively legal; skipping internals"
1934 : "");
1935 logger.startLine() << logLineComment;
1938 // If this operation is recursively legal, mark its children as ignored so
1939 // that we don't consider them for legalization.
1940 if (legalityInfo->isRecursivelyLegal) {
1941 op->walk([&](Operation *nested) {
1942 if (op != nested)
1943 rewriter.getImpl().ignoredOps.insert(nested);
1947 return success();
1950 // Check to see if the operation is ignored and doesn't need to be converted.
1951 if (rewriter.getImpl().isOpIgnored(op)) {
1952 LLVM_DEBUG({
1953 logSuccess(logger, "operation marked 'ignored' during conversion");
1954 logger.startLine() << logLineComment;
1956 return success();
1959 // If the operation isn't legal, try to fold it in-place.
1960 // TODO: Should we always try to do this, even if the op is
1961 // already legal?
1962 if (succeeded(legalizeWithFold(op, rewriter))) {
1963 LLVM_DEBUG({
1964 logSuccess(logger, "operation was folded");
1965 logger.startLine() << logLineComment;
1967 return success();
1970 // Otherwise, we need to apply a legalization pattern to this operation.
1971 if (succeeded(legalizeWithPattern(op, rewriter))) {
1972 LLVM_DEBUG({
1973 logSuccess(logger, "");
1974 logger.startLine() << logLineComment;
1976 return success();
1979 LLVM_DEBUG({
1980 logFailure(logger, "no matched legalization pattern");
1981 logger.startLine() << logLineComment;
1983 return failure();
1986 LogicalResult
1987 OperationLegalizer::legalizeWithFold(Operation *op,
1988 ConversionPatternRewriter &rewriter) {
1989 auto &rewriterImpl = rewriter.getImpl();
1990 RewriterState curState = rewriterImpl.getCurrentState();
1992 LLVM_DEBUG({
1993 rewriterImpl.logger.startLine() << "* Fold {\n";
1994 rewriterImpl.logger.indent();
1997 // Try to fold the operation.
1998 SmallVector<Value, 2> replacementValues;
1999 rewriter.setInsertionPoint(op);
2000 if (failed(rewriter.tryFold(op, replacementValues))) {
2001 LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2002 return failure();
2004 // An empty list of replacement values indicates that the fold was in-place.
2005 // As the operation changed, a new legalization needs to be attempted.
2006 if (replacementValues.empty())
2007 return legalize(op, rewriter);
2009 // Insert a replacement for 'op' with the folded replacement values.
2010 rewriter.replaceOp(op, replacementValues);
2012 // Recursively legalize any new constant operations.
2013 for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
2014 i != e; ++i) {
2015 auto *createOp =
2016 dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
2017 if (!createOp)
2018 continue;
2019 if (failed(legalize(createOp->getOperation(), rewriter))) {
2020 LLVM_DEBUG(logFailure(rewriterImpl.logger,
2021 "failed to legalize generated constant '{0}'",
2022 createOp->getOperation()->getName()));
2023 rewriterImpl.resetState(curState);
2024 return failure();
2028 LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2029 return success();
2032 LogicalResult
2033 OperationLegalizer::legalizeWithPattern(Operation *op,
2034 ConversionPatternRewriter &rewriter) {
2035 auto &rewriterImpl = rewriter.getImpl();
2037 // Functor that returns if the given pattern may be applied.
2038 auto canApply = [&](const Pattern &pattern) {
2039 bool canApply = canApplyPattern(op, pattern, rewriter);
2040 if (canApply && config.listener)
2041 config.listener->notifyPatternBegin(pattern, op);
2042 return canApply;
2045 // Functor that cleans up the rewriter state after a pattern failed to match.
2046 RewriterState curState = rewriterImpl.getCurrentState();
2047 auto onFailure = [&](const Pattern &pattern) {
2048 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2049 LLVM_DEBUG({
2050 logFailure(rewriterImpl.logger, "pattern failed to match");
2051 if (rewriterImpl.config.notifyCallback) {
2052 Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
2053 diag << "Failed to apply pattern \"" << pattern.getDebugName()
2054 << "\" on op:\n"
2055 << *op;
2056 rewriterImpl.config.notifyCallback(diag);
2059 if (config.listener)
2060 config.listener->notifyPatternEnd(pattern, failure());
2061 rewriterImpl.resetState(curState);
2062 appliedPatterns.erase(&pattern);
2065 // Functor that performs additional legalization when a pattern is
2066 // successfully applied.
2067 auto onSuccess = [&](const Pattern &pattern) {
2068 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2069 auto result = legalizePatternResult(op, pattern, rewriter, curState);
2070 appliedPatterns.erase(&pattern);
2071 if (failed(result))
2072 rewriterImpl.resetState(curState);
2073 if (config.listener)
2074 config.listener->notifyPatternEnd(pattern, result);
2075 return result;
2078 // Try to match and rewrite a pattern on this operation.
2079 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2080 onSuccess);
2083 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
2084 ConversionPatternRewriter &rewriter) {
2085 LLVM_DEBUG({
2086 auto &os = rewriter.getImpl().logger;
2087 os.getOStream() << "\n";
2088 os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2089 llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2090 os.getOStream() << ")' {\n";
2091 os.indent();
2094 // Ensure that we don't cycle by not allowing the same pattern to be
2095 // applied twice in the same recursion stack if it is not known to be safe.
2096 if (!pattern.hasBoundedRewriteRecursion() &&
2097 !appliedPatterns.insert(&pattern).second) {
2098 LLVM_DEBUG(
2099 logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2100 return false;
2102 return true;
2105 LogicalResult
2106 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
2107 ConversionPatternRewriter &rewriter,
2108 RewriterState &curState) {
2109 auto &impl = rewriter.getImpl();
2111 #ifndef NDEBUG
2112 assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2113 // Check that the root was either replaced or updated in place.
2114 auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2115 auto replacedRoot = [&] {
2116 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2118 auto updatedRootInPlace = [&] {
2119 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2121 assert((replacedRoot() || updatedRootInPlace()) &&
2122 "expected pattern to replace the root operation");
2123 #endif // NDEBUG
2125 // Legalize each of the actions registered during application.
2126 RewriterState newState = impl.getCurrentState();
2127 if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState,
2128 newState)) ||
2129 failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
2130 failed(legalizePatternCreatedOperations(rewriter, impl, curState,
2131 newState))) {
2132 return failure();
2135 LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2136 return success();
2139 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2140 Operation *op, ConversionPatternRewriter &rewriter,
2141 ConversionPatternRewriterImpl &impl, RewriterState &state,
2142 RewriterState &newState) {
2143 SmallPtrSet<Operation *, 16> operationsToIgnore;
2145 // If the pattern moved or created any blocks, make sure the types of block
2146 // arguments get legalized.
2147 for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2148 BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get());
2149 if (!rewrite)
2150 continue;
2151 Block *block = rewrite->getBlock();
2152 if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2153 ReplaceBlockArgRewrite>(rewrite))
2154 continue;
2155 // Only check blocks outside of the current operation.
2156 Operation *parentOp = block->getParentOp();
2157 if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2158 continue;
2160 // If the region of the block has a type converter, try to convert the block
2161 // directly.
2162 if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2163 std::optional<TypeConverter::SignatureConversion> conversion =
2164 converter->convertBlockSignature(block);
2165 if (!conversion) {
2166 LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2167 "block"));
2168 return failure();
2170 impl.applySignatureConversion(rewriter, block, converter, *conversion);
2171 continue;
2174 // Otherwise, check that this operation isn't one generated by this pattern.
2175 // This is because we will attempt to legalize the parent operation, and
2176 // blocks in regions created by this pattern will already be legalized later
2177 // on. If we haven't built the set yet, build it now.
2178 if (operationsToIgnore.empty()) {
2179 for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
2180 ++i) {
2181 auto *createOp =
2182 dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2183 if (!createOp)
2184 continue;
2185 operationsToIgnore.insert(createOp->getOperation());
2189 // If this operation should be considered for re-legalization, try it.
2190 if (operationsToIgnore.insert(parentOp).second &&
2191 failed(legalize(parentOp, rewriter))) {
2192 LLVM_DEBUG(logFailure(impl.logger,
2193 "operation '{0}'({1}) became illegal after rewrite",
2194 parentOp->getName(), parentOp));
2195 return failure();
2198 return success();
2201 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2202 ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
2203 RewriterState &state, RewriterState &newState) {
2204 for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2205 auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2206 if (!createOp)
2207 continue;
2208 Operation *op = createOp->getOperation();
2209 if (failed(legalize(op, rewriter))) {
2210 LLVM_DEBUG(logFailure(impl.logger,
2211 "failed to legalize generated operation '{0}'({1})",
2212 op->getName(), op));
2213 return failure();
2216 return success();
2219 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2220 ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
2221 RewriterState &state, RewriterState &newState) {
2222 for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2223 auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
2224 if (!rewrite)
2225 continue;
2226 Operation *op = rewrite->getOperation();
2227 if (failed(legalize(op, rewriter))) {
2228 LLVM_DEBUG(logFailure(
2229 impl.logger, "failed to legalize operation updated in-place '{0}'",
2230 op->getName()));
2231 return failure();
2234 return success();
//===----------------------------------------------------------------------===//
// Cost Model
//===----------------------------------------------------------------------===//
2240 void OperationLegalizer::buildLegalizationGraph(
2241 LegalizationPatterns &anyOpLegalizerPatterns,
2242 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2243 // A mapping between an operation and a set of operations that can be used to
2244 // generate it.
2245 DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps;
2246 // A mapping between an operation and any currently invalid patterns it has.
2247 DenseMap<OperationName, SmallPtrSet<const Pattern *, 2>> invalidPatterns;
2248 // A worklist of patterns to consider for legality.
2249 SetVector<const Pattern *> patternWorklist;
2251 // Build the mapping from operations to the parent ops that may generate them.
2252 applicator.walkAllPatterns([&](const Pattern &pattern) {
2253 std::optional<OperationName> root = pattern.getRootKind();
2255 // If the pattern has no specific root, we can't analyze the relationship
2256 // between the root op and generated operations. Given that, add all such
2257 // patterns to the legalization set.
2258 if (!root) {
2259 anyOpLegalizerPatterns.push_back(&pattern);
2260 return;
2263 // Skip operations that are always known to be legal.
2264 if (target.getOpAction(*root) == LegalizationAction::Legal)
2265 return;
2267 // Add this pattern to the invalid set for the root op and record this root
2268 // as a parent for any generated operations.
2269 invalidPatterns[*root].insert(&pattern);
2270 for (auto op : pattern.getGeneratedOps())
2271 parentOps[op].insert(*root);
2273 // Add this pattern to the worklist.
2274 patternWorklist.insert(&pattern);
2277 // If there are any patterns that don't have a specific root kind, we can't
2278 // make direct assumptions about what operations will never be legalized.
2279 // Note: Technically we could, but it would require an analysis that may
2280 // recurse into itself. It would be better to perform this kind of filtering
2281 // at a higher level than here anyways.
2282 if (!anyOpLegalizerPatterns.empty()) {
2283 for (const Pattern *pattern : patternWorklist)
2284 legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2285 return;
2288 while (!patternWorklist.empty()) {
2289 auto *pattern = patternWorklist.pop_back_val();
2291 // Check to see if any of the generated operations are invalid.
2292 if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2293 std::optional<LegalizationAction> action = target.getOpAction(op);
2294 return !legalizerPatterns.count(op) &&
2295 (!action || action == LegalizationAction::Illegal);
2297 continue;
2299 // Otherwise, if all of the generated operation are valid, this op is now
2300 // legal so add all of the child patterns to the worklist.
2301 legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2302 invalidPatterns[*pattern->getRootKind()].erase(pattern);
2304 // Add any invalid patterns of the parent operations to see if they have now
2305 // become legal.
2306 for (auto op : parentOps[*pattern->getRootKind()])
2307 patternWorklist.set_union(invalidPatterns[op]);
2311 void OperationLegalizer::computeLegalizationGraphBenefit(
2312 LegalizationPatterns &anyOpLegalizerPatterns,
2313 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2314 // The smallest pattern depth, when legalizing an operation.
2315 DenseMap<OperationName, unsigned> minOpPatternDepth;
2317 // For each operation that is transitively legal, compute a cost for it.
2318 for (auto &opIt : legalizerPatterns)
2319 if (!minOpPatternDepth.count(opIt.first))
2320 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2321 legalizerPatterns);
2323 // Apply the cost model to the patterns that can match any operation. Those
2324 // with a specific operation type are already resolved when computing the op
2325 // legalization depth.
2326 if (!anyOpLegalizerPatterns.empty())
2327 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2328 legalizerPatterns);
2330 // Apply a cost model to the pattern applicator. We order patterns first by
2331 // depth then benefit. `legalizerPatterns` contains per-op patterns by
2332 // decreasing benefit.
2333 applicator.applyCostModel([&](const Pattern &pattern) {
2334 ArrayRef<const Pattern *> orderedPatternList;
2335 if (std::optional<OperationName> rootName = pattern.getRootKind())
2336 orderedPatternList = legalizerPatterns[*rootName];
2337 else
2338 orderedPatternList = anyOpLegalizerPatterns;
2340 // If the pattern is not found, then it was removed and cannot be matched.
2341 auto *it = llvm::find(orderedPatternList, &pattern);
2342 if (it == orderedPatternList.end())
2343 return PatternBenefit::impossibleToMatch();
2345 // Patterns found earlier in the list have higher benefit.
2346 return PatternBenefit(std::distance(it, orderedPatternList.end()));
2350 unsigned OperationLegalizer::computeOpLegalizationDepth(
2351 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2352 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2353 // Check for existing depth.
2354 auto depthIt = minOpPatternDepth.find(op);
2355 if (depthIt != minOpPatternDepth.end())
2356 return depthIt->second;
2358 // If a mapping for this operation does not exist, then this operation
2359 // is always legal. Return 0 as the depth for a directly legal operation.
2360 auto opPatternsIt = legalizerPatterns.find(op);
2361 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2362 return 0u;
2364 // Record this initial depth in case we encounter this op again when
2365 // recursively computing the depth.
2366 minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2368 // Apply the cost model to the operation patterns, and update the minimum
2369 // depth.
2370 unsigned minDepth = applyCostModelToPatterns(
2371 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2372 minOpPatternDepth[op] = minDepth;
2373 return minDepth;
2376 unsigned OperationLegalizer::applyCostModelToPatterns(
2377 LegalizationPatterns &patterns,
2378 DenseMap<OperationName, unsigned> &minOpPatternDepth,
2379 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2380 unsigned minDepth = std::numeric_limits<unsigned>::max();
2382 // Compute the depth for each pattern within the set.
2383 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
2384 patternsByDepth.reserve(patterns.size());
2385 for (const Pattern *pattern : patterns) {
2386 unsigned depth = 1;
2387 for (auto generatedOp : pattern->getGeneratedOps()) {
2388 unsigned generatedOpDepth = computeOpLegalizationDepth(
2389 generatedOp, minOpPatternDepth, legalizerPatterns);
2390 depth = std::max(depth, generatedOpDepth + 1);
2392 patternsByDepth.emplace_back(pattern, depth);
2394 // Update the minimum depth of the pattern list.
2395 minDepth = std::min(minDepth, depth);
2398 // If the operation only has one legalization pattern, there is no need to
2399 // sort them.
2400 if (patternsByDepth.size() == 1)
2401 return minDepth;
2403 // Sort the patterns by those likely to be the most beneficial.
2404 std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(),
2405 [](const std::pair<const Pattern *, unsigned> &lhs,
2406 const std::pair<const Pattern *, unsigned> &rhs) {
2407 // First sort by the smaller pattern legalization
2408 // depth.
2409 if (lhs.second != rhs.second)
2410 return lhs.second < rhs.second;
2412 // Then sort by the larger pattern benefit.
2413 auto lhsBenefit = lhs.first->getBenefit();
2414 auto rhsBenefit = rhs.first->getBenefit();
2415 return lhsBenefit > rhsBenefit;
2418 // Update the legalization pattern to use the new sorted list.
2419 patterns.clear();
2420 for (auto &patternIt : patternsByDepth)
2421 patterns.push_back(patternIt.first);
2422 return minDepth;
2425 //===----------------------------------------------------------------------===//
2426 // OperationConverter
2427 //===----------------------------------------------------------------------===//
namespace {
/// Selects how OperationConverter treats operations that fail to legalize.
enum OpConversionMode {
  /// Failed legalizations are tolerated, so illegal operations may coexist in
  /// the IR. The exception is ops explicitly marked illegal: failing to
  /// legalize those is still an error.
  Partial,

  /// Every operation must become legal on the target for the conversion to
  /// succeed.
  Full,

  /// Operations are only analyzed for legality; no actual rewrites are applied
  /// to the operations on success.
  Analysis,
};
} // namespace
2444 namespace mlir {
2445 // This class converts operations to a given conversion target via a set of
2446 // rewrite patterns. The conversion behaves differently depending on the
2447 // conversion mode.
2448 struct OperationConverter {
2449 explicit OperationConverter(const ConversionTarget &target,
2450 const FrozenRewritePatternSet &patterns,
2451 const ConversionConfig &config,
2452 OpConversionMode mode)
2453 : config(config), opLegalizer(target, patterns, this->config),
2454 mode(mode) {}
2456 /// Converts the given operations to the conversion target.
2457 LogicalResult convertOperations(ArrayRef<Operation *> ops);
2459 private:
2460 /// Converts an operation with the given rewriter.
2461 LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2463 /// Dialect conversion configuration.
2464 ConversionConfig config;
2466 /// The legalizer to use when converting operations.
2467 OperationLegalizer opLegalizer;
2469 /// The conversion mode to use when legalizing operations.
2470 OpConversionMode mode;
2472 } // namespace mlir
2474 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2475 Operation *op) {
2476 // Legalize the given operation.
2477 if (failed(opLegalizer.legalize(op, rewriter))) {
2478 // Handle the case of a failed conversion for each of the different modes.
2479 // Full conversions expect all operations to be converted.
2480 if (mode == OpConversionMode::Full)
2481 return op->emitError()
2482 << "failed to legalize operation '" << op->getName() << "'";
2483 // Partial conversions allow conversions to fail iff the operation was not
2484 // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2485 // set, non-legalizable ops are added to that set.
2486 if (mode == OpConversionMode::Partial) {
2487 if (opLegalizer.isIllegal(op))
2488 return op->emitError()
2489 << "failed to legalize operation '" << op->getName()
2490 << "' that was explicitly marked illegal";
2491 if (config.unlegalizedOps)
2492 config.unlegalizedOps->insert(op);
2494 } else if (mode == OpConversionMode::Analysis) {
2495 // Analysis conversions don't fail if any operations fail to legalize,
2496 // they are only interested in the operations that were successfully
2497 // legalized.
2498 if (config.legalizableOps)
2499 config.legalizableOps->insert(op);
2501 return success();
2504 static LogicalResult
2505 legalizeUnresolvedMaterialization(RewriterBase &rewriter,
2506 UnresolvedMaterializationRewrite *rewrite) {
2507 UnrealizedConversionCastOp op = rewrite->getOperation();
2508 assert(!op.use_empty() &&
2509 "expected that dead materializations have already been DCE'd");
2510 Operation::operand_range inputOperands = op.getOperands();
2511 Type outputType = op.getResultTypes()[0];
2513 // Try to materialize the conversion.
2514 if (const TypeConverter *converter = rewrite->getConverter()) {
2515 rewriter.setInsertionPoint(op);
2516 Value newMaterialization;
2517 switch (rewrite->getMaterializationKind()) {
2518 case MaterializationKind::Argument:
2519 // Try to materialize an argument conversion.
2520 newMaterialization = converter->materializeArgumentConversion(
2521 rewriter, op->getLoc(), outputType, inputOperands);
2522 if (newMaterialization)
2523 break;
2524 // If an argument materialization failed, fallback to trying a target
2525 // materialization.
2526 [[fallthrough]];
2527 case MaterializationKind::Target:
2528 newMaterialization = converter->materializeTargetConversion(
2529 rewriter, op->getLoc(), outputType, inputOperands,
2530 rewrite->getOriginalType());
2531 break;
2532 case MaterializationKind::Source:
2533 newMaterialization = converter->materializeSourceConversion(
2534 rewriter, op->getLoc(), outputType, inputOperands);
2535 break;
2537 if (newMaterialization) {
2538 assert(newMaterialization.getType() == outputType &&
2539 "materialization callback produced value of incorrect type");
2540 rewriter.replaceOp(op, newMaterialization);
2541 return success();
2545 InFlightDiagnostic diag =
2546 op->emitError() << "failed to legalize unresolved materialization "
2547 "from ("
2548 << inputOperands.getTypes() << ") to (" << outputType
2549 << ") that remained live after conversion";
2550 diag.attachNote(op->getUsers().begin()->getLoc())
2551 << "see existing live user here: " << *op->getUsers().begin();
2552 return failure();
2555 LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2556 if (ops.empty())
2557 return success();
2558 const ConversionTarget &target = opLegalizer.getTarget();
2560 // Compute the set of operations and blocks to convert.
2561 SmallVector<Operation *> toConvert;
2562 for (auto *op : ops) {
2563 op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
2564 [&](Operation *op) {
2565 toConvert.push_back(op);
2566 // Don't check this operation's children for conversion if the
2567 // operation is recursively legal.
2568 auto legalityInfo = target.isLegal(op);
2569 if (legalityInfo && legalityInfo->isRecursivelyLegal)
2570 return WalkResult::skip();
2571 return WalkResult::advance();
2575 // Convert each operation and discard rewrites on failure.
2576 ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
2577 ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2579 for (auto *op : toConvert)
2580 if (failed(convert(rewriter, op)))
2581 return rewriterImpl.undoRewrites(), failure();
2583 // After a successful conversion, apply rewrites.
2584 rewriterImpl.applyRewrites();
2586 // Gather all unresolved materializations.
2587 SmallVector<UnrealizedConversionCastOp> allCastOps;
2588 const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
2589 &materializations = rewriterImpl.unresolvedMaterializations;
2590 for (auto it : materializations) {
2591 if (rewriterImpl.eraseRewriter.wasErased(it.first))
2592 continue;
2593 allCastOps.push_back(it.first);
2596 // Reconcile all UnrealizedConversionCastOps that were inserted by the
2597 // dialect conversion frameworks. (Not the one that were inserted by
2598 // patterns.)
2599 SmallVector<UnrealizedConversionCastOp> remainingCastOps;
2600 reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
2602 // Try to legalize all unresolved materializations.
2603 if (config.buildMaterializations) {
2604 IRRewriter rewriter(rewriterImpl.context, config.listener);
2605 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2606 auto it = materializations.find(castOp);
2607 assert(it != materializations.end() && "inconsistent state");
2608 if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
2609 return failure();
2613 return success();
2616 //===----------------------------------------------------------------------===//
2617 // Reconcile Unrealized Casts
2618 //===----------------------------------------------------------------------===//
2620 void mlir::reconcileUnrealizedCasts(
2621 ArrayRef<UnrealizedConversionCastOp> castOps,
2622 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
2623 SetVector<UnrealizedConversionCastOp> worklist(castOps.begin(),
2624 castOps.end());
2625 // This set is maintained only if `remainingCastOps` is provided.
2626 DenseSet<Operation *> erasedOps;
2628 // Helper function that adds all operands to the worklist that are an
2629 // unrealized_conversion_cast op result.
2630 auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
2631 for (Value v : castOp.getInputs())
2632 if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
2633 worklist.insert(inputCastOp);
2636 // Helper function that return the unrealized_conversion_cast op that
2637 // defines all inputs of the given op (in the same order). Return "nullptr"
2638 // if there is no such op.
2639 auto getInputCast =
2640 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
2641 if (castOp.getInputs().empty())
2642 return {};
2643 auto inputCastOp =
2644 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
2645 if (!inputCastOp)
2646 return {};
2647 if (inputCastOp.getOutputs() != castOp.getInputs())
2648 return {};
2649 return inputCastOp;
2652 // Process ops in the worklist bottom-to-top.
2653 while (!worklist.empty()) {
2654 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
2655 if (castOp->use_empty()) {
2656 // DCE: If the op has no users, erase it. Add the operands to the
2657 // worklist to find additional DCE opportunities.
2658 enqueueOperands(castOp);
2659 if (remainingCastOps)
2660 erasedOps.insert(castOp.getOperation());
2661 castOp->erase();
2662 continue;
2665 // Traverse the chain of input cast ops to see if an op with the same
2666 // input types can be found.
2667 UnrealizedConversionCastOp nextCast = castOp;
2668 while (nextCast) {
2669 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
2670 // Found a cast where the input types match the output types of the
2671 // matched op. We can directly use those inputs and the matched op can
2672 // be removed.
2673 enqueueOperands(castOp);
2674 castOp.replaceAllUsesWith(nextCast.getInputs());
2675 if (remainingCastOps)
2676 erasedOps.insert(castOp.getOperation());
2677 castOp->erase();
2678 break;
2680 nextCast = getInputCast(nextCast);
2684 if (remainingCastOps)
2685 for (UnrealizedConversionCastOp op : castOps)
2686 if (!erasedOps.contains(op.getOperation()))
2687 remainingCastOps->push_back(op);
2690 //===----------------------------------------------------------------------===//
2691 // Type Conversion
2692 //===----------------------------------------------------------------------===//
2694 void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
2695 ArrayRef<Type> types) {
2696 assert(!types.empty() && "expected valid types");
2697 remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2698 addInputs(types);
2701 void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
2702 assert(!types.empty() &&
2703 "1->0 type remappings don't need to be added explicitly");
2704 argTypes.append(types.begin(), types.end());
2707 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2708 unsigned newInputNo,
2709 unsigned newInputCount) {
2710 assert(!remappedInputs[origInputNo] && "input has already been remapped");
2711 assert(newInputCount != 0 && "expected valid input count");
2712 remappedInputs[origInputNo] =
2713 InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
2716 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2717 Value replacementValue) {
2718 assert(!remappedInputs[origInputNo] && "input has already been remapped");
2719 remappedInputs[origInputNo] =
2720 InputMapping{origInputNo, /*size=*/0, replacementValue};
2723 LogicalResult TypeConverter::convertType(Type t,
2724 SmallVectorImpl<Type> &results) const {
2726 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
2727 std::defer_lock);
2728 if (t.getContext()->isMultithreadingEnabled())
2729 cacheReadLock.lock();
2730 auto existingIt = cachedDirectConversions.find(t);
2731 if (existingIt != cachedDirectConversions.end()) {
2732 if (existingIt->second)
2733 results.push_back(existingIt->second);
2734 return success(existingIt->second != nullptr);
2736 auto multiIt = cachedMultiConversions.find(t);
2737 if (multiIt != cachedMultiConversions.end()) {
2738 results.append(multiIt->second.begin(), multiIt->second.end());
2739 return success();
2742 // Walk the added converters in reverse order to apply the most recently
2743 // registered first.
2744 size_t currentCount = results.size();
2746 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
2747 std::defer_lock);
2749 for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2750 if (std::optional<LogicalResult> result = converter(t, results)) {
2751 if (t.getContext()->isMultithreadingEnabled())
2752 cacheWriteLock.lock();
2753 if (!succeeded(*result)) {
2754 cachedDirectConversions.try_emplace(t, nullptr);
2755 return failure();
2757 auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2758 if (newTypes.size() == 1)
2759 cachedDirectConversions.try_emplace(t, newTypes.front());
2760 else
2761 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2762 return success();
2765 return failure();
2768 Type TypeConverter::convertType(Type t) const {
2769 // Use the multi-type result version to convert the type.
2770 SmallVector<Type, 1> results;
2771 if (failed(convertType(t, results)))
2772 return nullptr;
2774 // Check to ensure that only one type was produced.
2775 return results.size() == 1 ? results.front() : nullptr;
2778 LogicalResult
2779 TypeConverter::convertTypes(TypeRange types,
2780 SmallVectorImpl<Type> &results) const {
2781 for (Type type : types)
2782 if (failed(convertType(type, results)))
2783 return failure();
2784 return success();
2787 bool TypeConverter::isLegal(Type type) const {
2788 return convertType(type) == type;
2790 bool TypeConverter::isLegal(Operation *op) const {
2791 return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
2794 bool TypeConverter::isLegal(Region *region) const {
2795 return llvm::all_of(*region, [this](Block &block) {
2796 return isLegal(block.getArgumentTypes());
2800 bool TypeConverter::isSignatureLegal(FunctionType ty) const {
2801 return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2804 LogicalResult
2805 TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
2806 SignatureConversion &result) const {
2807 // Try to convert the given input type.
2808 SmallVector<Type, 1> convertedTypes;
2809 if (failed(convertType(type, convertedTypes)))
2810 return failure();
2812 // If this argument is being dropped, there is nothing left to do.
2813 if (convertedTypes.empty())
2814 return success();
2816 // Otherwise, add the new inputs.
2817 result.addInputs(inputNo, convertedTypes);
2818 return success();
2820 LogicalResult
2821 TypeConverter::convertSignatureArgs(TypeRange types,
2822 SignatureConversion &result,
2823 unsigned origInputOffset) const {
2824 for (unsigned i = 0, e = types.size(); i != e; ++i)
2825 if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
2826 return failure();
2827 return success();
2830 Value TypeConverter::materializeArgumentConversion(OpBuilder &builder,
2831 Location loc,
2832 Type resultType,
2833 ValueRange inputs) const {
2834 for (const MaterializationCallbackFn &fn :
2835 llvm::reverse(argumentMaterializations))
2836 if (Value result = fn(builder, resultType, inputs, loc))
2837 return result;
2838 return nullptr;
2841 Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
2842 Location loc, Type resultType,
2843 ValueRange inputs) const {
2844 for (const MaterializationCallbackFn &fn :
2845 llvm::reverse(sourceMaterializations))
2846 if (Value result = fn(builder, resultType, inputs, loc))
2847 return result;
2848 return nullptr;
2851 Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
2852 Location loc, Type resultType,
2853 ValueRange inputs,
2854 Type originalType) const {
2855 SmallVector<Value> result = materializeTargetConversion(
2856 builder, loc, TypeRange(resultType), inputs, originalType);
2857 if (result.empty())
2858 return nullptr;
2859 assert(result.size() == 1 && "expected single result");
2860 return result.front();
2863 SmallVector<Value> TypeConverter::materializeTargetConversion(
2864 OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
2865 Type originalType) const {
2866 for (const TargetMaterializationCallbackFn &fn :
2867 llvm::reverse(targetMaterializations)) {
2868 SmallVector<Value> result =
2869 fn(builder, resultTypes, inputs, loc, originalType);
2870 if (result.empty())
2871 continue;
2872 assert(TypeRange(ValueRange(result)) == resultTypes &&
2873 "callback produced incorrect number of values or values with "
2874 "incorrect types");
2875 return result;
2877 return {};
2880 std::optional<TypeConverter::SignatureConversion>
2881 TypeConverter::convertBlockSignature(Block *block) const {
2882 SignatureConversion conversion(block->getNumArguments());
2883 if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
2884 return std::nullopt;
2885 return conversion;
2888 //===----------------------------------------------------------------------===//
2889 // Type attribute conversion
2890 //===----------------------------------------------------------------------===//
2891 TypeConverter::AttributeConversionResult
2892 TypeConverter::AttributeConversionResult::result(Attribute attr) {
2893 return AttributeConversionResult(attr, resultTag);
2896 TypeConverter::AttributeConversionResult
2897 TypeConverter::AttributeConversionResult::na() {
2898 return AttributeConversionResult(nullptr, naTag);
2901 TypeConverter::AttributeConversionResult
2902 TypeConverter::AttributeConversionResult::abort() {
2903 return AttributeConversionResult(nullptr, abortTag);
2906 bool TypeConverter::AttributeConversionResult::hasResult() const {
2907 return impl.getInt() == resultTag;
2910 bool TypeConverter::AttributeConversionResult::isNa() const {
2911 return impl.getInt() == naTag;
2914 bool TypeConverter::AttributeConversionResult::isAbort() const {
2915 return impl.getInt() == abortTag;
2918 Attribute TypeConverter::AttributeConversionResult::getResult() const {
2919 assert(hasResult() && "Cannot get result from N/A or abort");
2920 return impl.getPointer();
2923 std::optional<Attribute>
2924 TypeConverter::convertTypeAttribute(Type type, Attribute attr) const {
2925 for (const TypeAttributeConversionCallbackFn &fn :
2926 llvm::reverse(typeAttributeConversions)) {
2927 AttributeConversionResult res = fn(type, attr);
2928 if (res.hasResult())
2929 return res.getResult();
2930 if (res.isAbort())
2931 return std::nullopt;
2933 return std::nullopt;
2936 //===----------------------------------------------------------------------===//
2937 // FunctionOpInterfaceSignatureConversion
2938 //===----------------------------------------------------------------------===//
2940 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
2941 const TypeConverter &typeConverter,
2942 ConversionPatternRewriter &rewriter) {
2943 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
2944 if (!type)
2945 return failure();
2947 // Convert the original function types.
2948 TypeConverter::SignatureConversion result(type.getNumInputs());
2949 SmallVector<Type, 1> newResults;
2950 if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
2951 failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
2952 failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
2953 typeConverter, &result)))
2954 return failure();
2956 // Update the function signature in-place.
2957 auto newType = FunctionType::get(rewriter.getContext(),
2958 result.getConvertedTypes(), newResults);
2960 rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
2962 return success();
2965 /// Create a default conversion pattern that rewrites the type signature of a
2966 /// FunctionOpInterface op. This only supports ops which use FunctionType to
2967 /// represent their type.
2968 namespace {
2969 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
2970 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
2971 MLIRContext *ctx,
2972 const TypeConverter &converter)
2973 : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
2975 LogicalResult
2976 matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
2977 ConversionPatternRewriter &rewriter) const override {
2978 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
2979 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
2983 struct AnyFunctionOpInterfaceSignatureConversion
2984 : public OpInterfaceConversionPattern<FunctionOpInterface> {
2985 using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
2987 LogicalResult
2988 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
2989 ConversionPatternRewriter &rewriter) const override {
2990 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
2993 } // namespace
2995 FailureOr<Operation *>
2996 mlir::convertOpResultTypes(Operation *op, ValueRange operands,
2997 const TypeConverter &converter,
2998 ConversionPatternRewriter &rewriter) {
2999 assert(op && "Invalid op");
3000 Location loc = op->getLoc();
3001 if (converter.isLegal(op))
3002 return rewriter.notifyMatchFailure(loc, "op already legal");
3004 OperationState newOp(loc, op->getName());
3005 newOp.addOperands(operands);
3007 SmallVector<Type> newResultTypes;
3008 if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
3009 return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3011 newOp.addTypes(newResultTypes);
3012 newOp.addAttributes(op->getAttrs());
3013 return rewriter.create(newOp);
3016 void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3017 StringRef functionLikeOpName, RewritePatternSet &patterns,
3018 const TypeConverter &converter) {
3019 patterns.add<FunctionOpInterfaceSignatureConversion>(
3020 functionLikeOpName, patterns.getContext(), converter);
3023 void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3024 RewritePatternSet &patterns, const TypeConverter &converter) {
3025 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3026 converter, patterns.getContext());
3029 //===----------------------------------------------------------------------===//
3030 // ConversionTarget
3031 //===----------------------------------------------------------------------===//
3033 void ConversionTarget::setOpAction(OperationName op,
3034 LegalizationAction action) {
3035 legalOperations[op].action = action;
3038 void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
3039 LegalizationAction action) {
3040 for (StringRef dialect : dialectNames)
3041 legalDialects[dialect] = action;
3044 auto ConversionTarget::getOpAction(OperationName op) const
3045 -> std::optional<LegalizationAction> {
3046 std::optional<LegalizationInfo> info = getOpInfo(op);
3047 return info ? info->action : std::optional<LegalizationAction>();
3050 auto ConversionTarget::isLegal(Operation *op) const
3051 -> std::optional<LegalOpDetails> {
3052 std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3053 if (!info)
3054 return std::nullopt;
3056 // Returns true if this operation instance is known to be legal.
3057 auto isOpLegal = [&] {
3058 // Handle dynamic legality either with the provided legality function.
3059 if (info->action == LegalizationAction::Dynamic) {
3060 std::optional<bool> result = info->legalityFn(op);
3061 if (result)
3062 return *result;
3065 // Otherwise, the operation is only legal if it was marked 'Legal'.
3066 return info->action == LegalizationAction::Legal;
3068 if (!isOpLegal())
3069 return std::nullopt;
3071 // This operation is legal, compute any additional legality information.
3072 LegalOpDetails legalityDetails;
3073 if (info->isRecursivelyLegal) {
3074 auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3075 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3076 legalityDetails.isRecursivelyLegal =
3077 legalityFnIt->second(op).value_or(true);
3078 } else {
3079 legalityDetails.isRecursivelyLegal = true;
3082 return legalityDetails;
3085 bool ConversionTarget::isIllegal(Operation *op) const {
3086 std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3087 if (!info)
3088 return false;
3090 if (info->action == LegalizationAction::Dynamic) {
3091 std::optional<bool> result = info->legalityFn(op);
3092 if (!result)
3093 return false;
3095 return !(*result);
3098 return info->action == LegalizationAction::Illegal;
3101 static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(
3102 ConversionTarget::DynamicLegalityCallbackFn oldCallback,
3103 ConversionTarget::DynamicLegalityCallbackFn newCallback) {
3104 if (!oldCallback)
3105 return newCallback;
3107 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3108 Operation *op) -> std::optional<bool> {
3109 if (std::optional<bool> result = newCl(op))
3110 return *result;
3112 return oldCl(op);
3114 return chain;
3117 void ConversionTarget::setLegalityCallback(
3118 OperationName name, const DynamicLegalityCallbackFn &callback) {
3119 assert(callback && "expected valid legality callback");
3120 auto *infoIt = legalOperations.find(name);
3121 assert(infoIt != legalOperations.end() &&
3122 infoIt->second.action == LegalizationAction::Dynamic &&
3123 "expected operation to already be marked as dynamically legal");
3124 infoIt->second.legalityFn =
3125 composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3128 void ConversionTarget::markOpRecursivelyLegal(
3129 OperationName name, const DynamicLegalityCallbackFn &callback) {
3130 auto *infoIt = legalOperations.find(name);
3131 assert(infoIt != legalOperations.end() &&
3132 infoIt->second.action != LegalizationAction::Illegal &&
3133 "expected operation to already be marked as legal");
3134 infoIt->second.isRecursivelyLegal = true;
3135 if (callback)
3136 opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3137 std::move(opRecursiveLegalityFns[name]), callback);
3138 else
3139 opRecursiveLegalityFns.erase(name);
3142 void ConversionTarget::setLegalityCallback(
3143 ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3144 assert(callback && "expected valid legality callback");
3145 for (StringRef dialect : dialects)
3146 dialectLegalityFns[dialect] = composeLegalityCallbacks(
3147 std::move(dialectLegalityFns[dialect]), callback);
3150 void ConversionTarget::setLegalityCallback(
3151 const DynamicLegalityCallbackFn &callback) {
3152 assert(callback && "expected valid legality callback");
3153 unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3156 auto ConversionTarget::getOpInfo(OperationName op) const
3157 -> std::optional<LegalizationInfo> {
3158 // Check for info for this specific operation.
3159 const auto *it = legalOperations.find(op);
3160 if (it != legalOperations.end())
3161 return it->second;
3162 // Check for info for the parent dialect.
3163 auto dialectIt = legalDialects.find(op.getDialectNamespace());
3164 if (dialectIt != legalDialects.end()) {
3165 DynamicLegalityCallbackFn callback;
3166 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3167 if (dialectFn != dialectLegalityFns.end())
3168 callback = dialectFn->second;
3169 return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
3170 callback};
3172 // Otherwise, check if we mark unknown operations as dynamic.
3173 if (unknownLegalityFn)
3174 return LegalizationInfo{LegalizationAction::Dynamic,
3175 /*isRecursivelyLegal=*/false, unknownLegalityFn};
3176 return std::nullopt;
3179 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3180 //===----------------------------------------------------------------------===//
3181 // PDL Configuration
3182 //===----------------------------------------------------------------------===//
3184 void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
3185 auto &rewriterImpl =
3186 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3187 rewriterImpl.currentTypeConverter = getTypeConverter();
3190 void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
3191 auto &rewriterImpl =
3192 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3193 rewriterImpl.currentTypeConverter = nullptr;
3196 /// Remap the given value using the rewriter and the type converter in the
3197 /// provided config.
3198 static FailureOr<SmallVector<Value>>
3199 pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values) {
3200 SmallVector<Value> mappedValues;
3201 if (failed(rewriter.getRemappedValues(values, mappedValues)))
3202 return failure();
3203 return std::move(mappedValues);
3206 void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
3207 patterns.getPDLPatterns().registerRewriteFunction(
3208 "convertValue",
3209 [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3210 auto results = pdllConvertValues(
3211 static_cast<ConversionPatternRewriter &>(rewriter), value);
3212 if (failed(results))
3213 return failure();
3214 return results->front();
3216 patterns.getPDLPatterns().registerRewriteFunction(
3217 "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
3218 return pdllConvertValues(
3219 static_cast<ConversionPatternRewriter &>(rewriter), values);
3221 patterns.getPDLPatterns().registerRewriteFunction(
3222 "convertType",
3223 [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3224 auto &rewriterImpl =
3225 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3226 if (const TypeConverter *converter =
3227 rewriterImpl.currentTypeConverter) {
3228 if (Type newType = converter->convertType(type))
3229 return newType;
3230 return failure();
3232 return type;
3234 patterns.getPDLPatterns().registerRewriteFunction(
3235 "convertTypes",
3236 [](PatternRewriter &rewriter,
3237 TypeRange types) -> FailureOr<SmallVector<Type>> {
3238 auto &rewriterImpl =
3239 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3240 const TypeConverter *converter = rewriterImpl.currentTypeConverter;
3241 if (!converter)
3242 return SmallVector<Type>(types);
3244 SmallVector<Type> remappedTypes;
3245 if (failed(converter->convertTypes(types, remappedTypes)))
3246 return failure();
3247 return std::move(remappedTypes);
3250 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
3252 //===----------------------------------------------------------------------===//
3253 // Op Conversion Entry Points
3254 //===----------------------------------------------------------------------===//
3256 //===----------------------------------------------------------------------===//
3257 // Partial Conversion
3259 LogicalResult mlir::applyPartialConversion(
3260 ArrayRef<Operation *> ops, const ConversionTarget &target,
3261 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3262 OperationConverter opConverter(target, patterns, config,
3263 OpConversionMode::Partial);
3264 return opConverter.convertOperations(ops);
3266 LogicalResult
3267 mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
3268 const FrozenRewritePatternSet &patterns,
3269 ConversionConfig config) {
3270 return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
3273 //===----------------------------------------------------------------------===//
3274 // Full Conversion
3276 LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
3277 const ConversionTarget &target,
3278 const FrozenRewritePatternSet &patterns,
3279 ConversionConfig config) {
3280 OperationConverter opConverter(target, patterns, config,
3281 OpConversionMode::Full);
3282 return opConverter.convertOperations(ops);
3284 LogicalResult mlir::applyFullConversion(Operation *op,
3285 const ConversionTarget &target,
3286 const FrozenRewritePatternSet &patterns,
3287 ConversionConfig config) {
3288 return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
3291 //===----------------------------------------------------------------------===//
3292 // Analysis Conversion
3294 /// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
3295 /// op is a top-level module op (which is expected to be isolated from above),
3296 /// return that op.
3297 static Operation *findCommonAncestor(ArrayRef<Operation *> ops) {
3298 // Check if there is a top-level operation within `ops`. If so, return that
3299 // op.
3300 for (Operation *op : ops) {
3301 if (!op->getParentOp()) {
3302 #ifndef NDEBUG
3303 assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
3304 "expected top-level op to be isolated from above");
3305 for (Operation *other : ops)
3306 assert(op->isAncestor(other) &&
3307 "expected ops to have a common ancestor");
3308 #endif // NDEBUG
3309 return op;
3313 // No top-level op. Find a common ancestor.
3314 Operation *commonAncestor =
3315 ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
3316 for (Operation *op : ops.drop_front()) {
3317 while (!commonAncestor->isProperAncestor(op)) {
3318 commonAncestor =
3319 commonAncestor->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
3320 assert(commonAncestor &&
3321 "expected to find a common isolated from above ancestor");
3325 return commonAncestor;
3328 LogicalResult mlir::applyAnalysisConversion(
3329 ArrayRef<Operation *> ops, ConversionTarget &target,
3330 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3331 #ifndef NDEBUG
3332 if (config.legalizableOps)
3333 assert(config.legalizableOps->empty() && "expected empty set");
3334 #endif // NDEBUG
3336 // Clone closted common ancestor that is isolated from above.
3337 Operation *commonAncestor = findCommonAncestor(ops);
3338 IRMapping mapping;
3339 Operation *clonedAncestor = commonAncestor->clone(mapping);
3340 // Compute inverse IR mapping.
3341 DenseMap<Operation *, Operation *> inverseOperationMap;
3342 for (auto &it : mapping.getOperationMap())
3343 inverseOperationMap[it.second] = it.first;
3345 // Convert the cloned operations. The original IR will remain unchanged.
3346 SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
3347 ops, [&](Operation *op) { return mapping.lookup(op); });
3348 OperationConverter opConverter(target, patterns, config,
3349 OpConversionMode::Analysis);
3350 LogicalResult status = opConverter.convertOperations(opsToConvert);
3352 // Remap `legalizableOps`, so that they point to the original ops and not the
3353 // cloned ops.
3354 if (config.legalizableOps) {
3355 DenseSet<Operation *> originalLegalizableOps;
3356 for (Operation *op : *config.legalizableOps)
3357 originalLegalizableOps.insert(inverseOperationMap[op]);
3358 *config.legalizableOps = std::move(originalLegalizableOps);
3361 // Erase the cloned IR.
3362 clonedAncestor->erase();
3363 return status;
3366 LogicalResult
3367 mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
3368 const FrozenRewritePatternSet &patterns,
3369 ConversionConfig config) {
3370 return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);