1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Transforms/DialectConversion.h"
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/IR/Iterators.h"
16 #include "mlir/Interfaces/FunctionInterfaces.h"
17 #include "mlir/Rewrite/PatternApplicator.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/Support/SaveAndRestore.h"
24 #include "llvm/Support/ScopedPrinter.h"
28 using namespace mlir::detail
;
30 #define DEBUG_TYPE "dialect-conversion"
32 /// A utility function to log a successful result for the given reason.
33 template <typename
... Args
>
34 static void logSuccess(llvm::ScopedPrinter
&os
, StringRef fmt
, Args
&&...args
) {
37 os
.startLine() << "} -> SUCCESS";
39 os
.getOStream() << " : "
40 << llvm::formatv(fmt
.data(), std::forward
<Args
>(args
)...);
41 os
.getOStream() << "\n";
45 /// A utility function to log a failure result for the given reason.
46 template <typename
... Args
>
47 static void logFailure(llvm::ScopedPrinter
&os
, StringRef fmt
, Args
&&...args
) {
50 os
.startLine() << "} -> FAILURE : "
51 << llvm::formatv(fmt
.data(), std::forward
<Args
>(args
)...)
56 /// Helper function that computes an insertion point where the given value is
57 /// defined and can be used without a dominance violation.
58 static OpBuilder::InsertPoint
computeInsertPoint(Value value
) {
59 Block
*insertBlock
= value
.getParentBlock();
60 Block::iterator insertPt
= insertBlock
->begin();
61 if (OpResult inputRes
= dyn_cast
<OpResult
>(value
))
62 insertPt
= ++inputRes
.getOwner()->getIterator();
63 return OpBuilder::InsertPoint(insertBlock
, insertPt
);
66 //===----------------------------------------------------------------------===//
67 // ConversionValueMapping
68 //===----------------------------------------------------------------------===//
70 /// A list of replacement SSA values. Optimized for the common case of a single
72 using ReplacementValues
= SmallVector
<Value
, 1>;
75 /// This class wraps a IRMapping to provide recursive lookup
76 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
77 struct ConversionValueMapping
{
78 /// Return "true" if an SSA value is mapped to the given value. May return
80 bool isMappedTo(Value value
) const { return mappedTo
.contains(value
); }
82 /// Lookup the most recently mapped value with the desired type in the
86 /// - If the desired type is "null", simply return the most recently mapped
88 /// - If there is no mapping to the desired type, also return the most
89 /// recently mapped value.
90 /// - If there is no mapping for the given value at all, return the given
92 Value
lookupOrDefault(Value from
, Type desiredType
= nullptr) const;
94 /// Lookup a mapped value within the map, or return null if a mapping does not
95 /// exist. If a mapping exists, this follows the same behavior of
96 /// `lookupOrDefault`.
97 Value
lookupOrNull(Value from
, Type desiredType
= nullptr) const;
99 /// Map a value to the one provided.
100 void map(Value oldVal
, Value newVal
) {
102 for (Value it
= newVal
; it
; it
= mapping
.lookupOrNull(it
))
103 assert(it
!= oldVal
&& "inserting cyclic mapping");
105 mapping
.map(oldVal
, newVal
);
106 mappedTo
.insert(newVal
);
109 /// Drop the last mapping for the given value.
110 void erase(Value value
) { mapping
.erase(value
); }
113 /// Current value mappings.
116 /// All SSA values that are mapped to. May contain false positives.
117 DenseSet
<Value
> mappedTo
;
121 Value
ConversionValueMapping::lookupOrDefault(Value from
,
122 Type desiredType
) const {
123 // Try to find the deepest value that has the desired type. If there is no
124 // such value, simply return the deepest value.
127 if (!desiredType
|| from
.getType() == desiredType
)
130 Value mappedValue
= mapping
.lookupOrNull(from
);
136 // If the desired value was found use it, otherwise default to the leaf value.
137 return desiredValue
? desiredValue
: from
;
140 Value
ConversionValueMapping::lookupOrNull(Value from
, Type desiredType
) const {
141 Value result
= lookupOrDefault(from
, desiredType
);
142 if (result
== from
|| (desiredType
&& result
.getType() != desiredType
))
147 //===----------------------------------------------------------------------===//
148 // Rewriter and Translation State
149 //===----------------------------------------------------------------------===//
/// This class contains a snapshot of the current conversion rewriter state.
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
                unsigned numReplacedOps)
      : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
        numReplacedOps(numReplacedOps) {}

  /// The current number of rewrites performed.
  unsigned numRewrites;

  /// The current number of ignored operations.
  unsigned numIgnoredOperations;

  /// The current number of replaced ops that are scheduled for erasure.
  unsigned numReplacedOps;
};
//===----------------------------------------------------------------------===//
// IR rewrites
//===----------------------------------------------------------------------===//
173 /// An IR rewrite that can be committed (upon success) or rolled back (upon
176 /// The dialect conversion keeps track of IR modifications (requested by the
177 /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
178 /// are directly applied to the IR as the rewriter API is used, some are applied
179 /// partially, and some are delayed until the `IRRewrite` objects are committed.
182 /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
183 /// Enum values are ordered, so that they can be used in `classof`: first all
184 /// block rewrites, then all operation rewrites.
193 // Operation rewrites
198 UnresolvedMaterialization
201 virtual ~IRRewrite() = default;
203 /// Roll back the rewrite. Operations may be erased during rollback.
204 virtual void rollback() = 0;
206 /// Commit the rewrite. At this point, it is certain that the dialect
207 /// conversion will succeed. All IR modifications, except for operation/block
208 /// erasure, must be performed through the given rewriter.
210 /// Instead of erasing operations/blocks, they should merely be unlinked
211 /// commit phase and finally be erased during the cleanup phase. This is
212 /// because internal dialect conversion state (such as `mapping`) may still
215 /// Any IR modification that was already performed before the commit phase
216 /// (e.g., insertion of an op) must be communicated to the listener that may
217 /// be attached to the given rewriter.
218 virtual void commit(RewriterBase
&rewriter
) {}
220 /// Cleanup operations/blocks. Cleanup is called after commit.
221 virtual void cleanup(RewriterBase
&rewriter
) {}
223 Kind
getKind() const { return kind
; }
225 static bool classof(const IRRewrite
*rewrite
) { return true; }
228 IRRewrite(Kind kind
, ConversionPatternRewriterImpl
&rewriterImpl
)
229 : kind(kind
), rewriterImpl(rewriterImpl
) {}
231 const ConversionConfig
&getConfig() const;
234 ConversionPatternRewriterImpl
&rewriterImpl
;
238 class BlockRewrite
: public IRRewrite
{
240 /// Return the block that this rewrite operates on.
241 Block
*getBlock() const { return block
; }
243 static bool classof(const IRRewrite
*rewrite
) {
244 return rewrite
->getKind() >= Kind::CreateBlock
&&
245 rewrite
->getKind() <= Kind::ReplaceBlockArg
;
249 BlockRewrite(Kind kind
, ConversionPatternRewriterImpl
&rewriterImpl
,
251 : IRRewrite(kind
, rewriterImpl
), block(block
) {}
253 // The block that this rewrite operates on.
257 /// Creation of a block. Block creations are immediately reflected in the IR.
258 /// There is no extra work to commit the rewrite. During rollback, the newly
259 /// created block is erased.
260 class CreateBlockRewrite
: public BlockRewrite
{
262 CreateBlockRewrite(ConversionPatternRewriterImpl
&rewriterImpl
, Block
*block
)
263 : BlockRewrite(Kind::CreateBlock
, rewriterImpl
, block
) {}
265 static bool classof(const IRRewrite
*rewrite
) {
266 return rewrite
->getKind() == Kind::CreateBlock
;
269 void commit(RewriterBase
&rewriter
) override
{
270 // The block was already created and inserted. Just inform the listener.
271 if (auto *listener
= rewriter
.getListener())
272 listener
->notifyBlockInserted(block
, /*previous=*/{}, /*previousIt=*/{});
275 void rollback() override
{
276 // Unlink all of the operations within this block, they will be deleted
278 auto &blockOps
= block
->getOperations();
279 while (!blockOps
.empty())
280 blockOps
.remove(blockOps
.begin());
281 block
->dropAllUses();
282 if (block
->getParent())
289 /// Erasure of a block. Block erasures are partially reflected in the IR. Erased
290 /// blocks are immediately unlinked, but only erased during cleanup. This makes
291 /// it easier to rollback a block erasure: the block is simply inserted into its
292 /// original location.
293 class EraseBlockRewrite
: public BlockRewrite
{
295 EraseBlockRewrite(ConversionPatternRewriterImpl
&rewriterImpl
, Block
*block
)
296 : BlockRewrite(Kind::EraseBlock
, rewriterImpl
, block
),
297 region(block
->getParent()), insertBeforeBlock(block
->getNextNode()) {}
299 static bool classof(const IRRewrite
*rewrite
) {
300 return rewrite
->getKind() == Kind::EraseBlock
;
303 ~EraseBlockRewrite() override
{
305 "rewrite was neither rolled back nor committed/cleaned up");
308 void rollback() override
{
309 // The block (owned by this rewrite) was not actually erased yet. It was
310 // just unlinked. Put it back into its original position.
311 assert(block
&& "expected block");
312 auto &blockList
= region
->getBlocks();
313 Region::iterator before
= insertBeforeBlock
314 ? Region::iterator(insertBeforeBlock
)
316 blockList
.insert(before
, block
);
320 void commit(RewriterBase
&rewriter
) override
{
322 assert(block
&& "expected block");
323 assert(block
->empty() && "expected empty block");
325 // Notify the listener that the block is about to be erased.
327 dyn_cast_or_null
<RewriterBase::Listener
>(rewriter
.getListener()))
328 listener
->notifyBlockErased(block
);
331 void cleanup(RewriterBase
&rewriter
) override
{
333 block
->dropAllDefinedValueUses();
339 // The region in which this block was previously contained.
342 // The original successor of this block before it was unlinked. "nullptr" if
343 // this block was the only block in the region.
344 Block
*insertBeforeBlock
;
347 /// Inlining of a block. This rewrite is immediately reflected in the IR.
348 /// Note: This rewrite represents only the inlining of the operations. The
349 /// erasure of the inlined block is a separate rewrite.
350 class InlineBlockRewrite
: public BlockRewrite
{
352 InlineBlockRewrite(ConversionPatternRewriterImpl
&rewriterImpl
, Block
*block
,
353 Block
*sourceBlock
, Block::iterator before
)
354 : BlockRewrite(Kind::InlineBlock
, rewriterImpl
, block
),
355 sourceBlock(sourceBlock
),
356 firstInlinedInst(sourceBlock
->empty() ? nullptr
357 : &sourceBlock
->front()),
358 lastInlinedInst(sourceBlock
->empty() ? nullptr : &sourceBlock
->back()) {
359 // If a listener is attached to the dialect conversion, ops must be moved
360 // one-by-one. When they are moved in bulk, notifications cannot be sent
361 // because the ops that used to be in the source block at the time of the
362 // inlining (before the "commit" phase) are unknown at the time when
363 // notifications are sent (which is during the "commit" phase).
364 assert(!getConfig().listener
&&
365 "InlineBlockRewrite not supported if listener is attached");
368 static bool classof(const IRRewrite
*rewrite
) {
369 return rewrite
->getKind() == Kind::InlineBlock
;
372 void rollback() override
{
373 // Put the operations from the destination block (owned by the rewrite)
374 // back into the source block.
375 if (firstInlinedInst
) {
376 assert(lastInlinedInst
&& "expected operation");
377 sourceBlock
->getOperations().splice(sourceBlock
->begin(),
378 block
->getOperations(),
379 Block::iterator(firstInlinedInst
),
380 ++Block::iterator(lastInlinedInst
));
385 // The block that originally contained the operations.
388 // The first inlined operation.
389 Operation
*firstInlinedInst
;
391 // The last inlined operation.
392 Operation
*lastInlinedInst
;
395 /// Moving of a block. This rewrite is immediately reflected in the IR.
396 class MoveBlockRewrite
: public BlockRewrite
{
398 MoveBlockRewrite(ConversionPatternRewriterImpl
&rewriterImpl
, Block
*block
,
399 Region
*region
, Block
*insertBeforeBlock
)
400 : BlockRewrite(Kind::MoveBlock
, rewriterImpl
, block
), region(region
),
401 insertBeforeBlock(insertBeforeBlock
) {}
403 static bool classof(const IRRewrite
*rewrite
) {
404 return rewrite
->getKind() == Kind::MoveBlock
;
407 void commit(RewriterBase
&rewriter
) override
{
408 // The block was already moved. Just inform the listener.
409 if (auto *listener
= rewriter
.getListener()) {
410 // Note: `previousIt` cannot be passed because this is a delayed
411 // notification and iterators into past IR state cannot be represented.
412 listener
->notifyBlockInserted(block
, /*previous=*/region
,
417 void rollback() override
{
418 // Move the block back to its original position.
419 Region::iterator before
=
420 insertBeforeBlock
? Region::iterator(insertBeforeBlock
) : region
->end();
421 region
->getBlocks().splice(before
, block
->getParent()->getBlocks(), block
);
425 // The region in which this block was previously contained.
428 // The original successor of this block before it was moved. "nullptr" if
429 // this block was the only block in the region.
430 Block
*insertBeforeBlock
;
433 /// Block type conversion. This rewrite is partially reflected in the IR.
434 class BlockTypeConversionRewrite
: public BlockRewrite
{
436 BlockTypeConversionRewrite(ConversionPatternRewriterImpl
&rewriterImpl
,
437 Block
*block
, Block
*origBlock
)
438 : BlockRewrite(Kind::BlockTypeConversion
, rewriterImpl
, block
),
439 origBlock(origBlock
) {}
441 static bool classof(const IRRewrite
*rewrite
) {
442 return rewrite
->getKind() == Kind::BlockTypeConversion
;
445 Block
*getOrigBlock() const { return origBlock
; }
447 void commit(RewriterBase
&rewriter
) override
;
449 void rollback() override
;
452 /// The original block that was requested to have its signature converted.
456 /// Replacing a block argument. This rewrite is not immediately reflected in the
457 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
458 /// until the rewrite is committed.
459 class ReplaceBlockArgRewrite
: public BlockRewrite
{
461 ReplaceBlockArgRewrite(ConversionPatternRewriterImpl
&rewriterImpl
,
462 Block
*block
, BlockArgument arg
,
463 const TypeConverter
*converter
)
464 : BlockRewrite(Kind::ReplaceBlockArg
, rewriterImpl
, block
), arg(arg
),
465 converter(converter
) {}
467 static bool classof(const IRRewrite
*rewrite
) {
468 return rewrite
->getKind() == Kind::ReplaceBlockArg
;
471 void commit(RewriterBase
&rewriter
) override
;
473 void rollback() override
;
478 /// The current type converter when the block argument was replaced.
479 const TypeConverter
*converter
;
482 /// An operation rewrite.
483 class OperationRewrite
: public IRRewrite
{
485 /// Return the operation that this rewrite operates on.
486 Operation
*getOperation() const { return op
; }
488 static bool classof(const IRRewrite
*rewrite
) {
489 return rewrite
->getKind() >= Kind::MoveOperation
&&
490 rewrite
->getKind() <= Kind::UnresolvedMaterialization
;
494 OperationRewrite(Kind kind
, ConversionPatternRewriterImpl
&rewriterImpl
,
496 : IRRewrite(kind
, rewriterImpl
), op(op
) {}
498 // The operation that this rewrite operates on.
502 /// Moving of an operation. This rewrite is immediately reflected in the IR.
503 class MoveOperationRewrite
: public OperationRewrite
{
505 MoveOperationRewrite(ConversionPatternRewriterImpl
&rewriterImpl
,
506 Operation
*op
, Block
*block
, Operation
*insertBeforeOp
)
507 : OperationRewrite(Kind::MoveOperation
, rewriterImpl
, op
), block(block
),
508 insertBeforeOp(insertBeforeOp
) {}
510 static bool classof(const IRRewrite
*rewrite
) {
511 return rewrite
->getKind() == Kind::MoveOperation
;
514 void commit(RewriterBase
&rewriter
) override
{
515 // The operation was already moved. Just inform the listener.
516 if (auto *listener
= rewriter
.getListener()) {
517 // Note: `previousIt` cannot be passed because this is a delayed
518 // notification and iterators into past IR state cannot be represented.
519 listener
->notifyOperationInserted(
520 op
, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block
,
525 void rollback() override
{
526 // Move the operation back to its original position.
527 Block::iterator before
=
528 insertBeforeOp
? Block::iterator(insertBeforeOp
) : block
->end();
529 block
->getOperations().splice(before
, op
->getBlock()->getOperations(), op
);
533 // The block in which this operation was previously contained.
536 // The original successor of this operation before it was moved. "nullptr"
537 // if this operation was the only operation in the region.
538 Operation
*insertBeforeOp
;
541 /// In-place modification of an op. This rewrite is immediately reflected in
542 /// the IR. The previous state of the operation is stored in this object.
543 class ModifyOperationRewrite
: public OperationRewrite
{
545 ModifyOperationRewrite(ConversionPatternRewriterImpl
&rewriterImpl
,
547 : OperationRewrite(Kind::ModifyOperation
, rewriterImpl
, op
),
548 name(op
->getName()), loc(op
->getLoc()), attrs(op
->getAttrDictionary()),
549 operands(op
->operand_begin(), op
->operand_end()),
550 successors(op
->successor_begin(), op
->successor_end()) {
551 if (OpaqueProperties prop
= op
->getPropertiesStorage()) {
552 // Make a copy of the properties.
553 propertiesStorage
= operator new(op
->getPropertiesStorageSize());
554 OpaqueProperties
propCopy(propertiesStorage
);
555 name
.initOpProperties(propCopy
, /*init=*/prop
);
559 static bool classof(const IRRewrite
*rewrite
) {
560 return rewrite
->getKind() == Kind::ModifyOperation
;
563 ~ModifyOperationRewrite() override
{
564 assert(!propertiesStorage
&&
565 "rewrite was neither committed nor rolled back");
568 void commit(RewriterBase
&rewriter
) override
{
569 // Notify the listener that the operation was modified in-place.
571 dyn_cast_or_null
<RewriterBase::Listener
>(rewriter
.getListener()))
572 listener
->notifyOperationModified(op
);
574 if (propertiesStorage
) {
575 OpaqueProperties
propCopy(propertiesStorage
);
576 // Note: The operation may have been erased in the mean time, so
577 // OperationName must be stored in this object.
578 name
.destroyOpProperties(propCopy
);
579 operator delete(propertiesStorage
);
580 propertiesStorage
= nullptr;
584 void rollback() override
{
587 op
->setOperands(operands
);
588 for (const auto &it
: llvm::enumerate(successors
))
589 op
->setSuccessor(it
.value(), it
.index());
590 if (propertiesStorage
) {
591 OpaqueProperties
propCopy(propertiesStorage
);
592 op
->copyProperties(propCopy
);
593 name
.destroyOpProperties(propCopy
);
594 operator delete(propertiesStorage
);
595 propertiesStorage
= nullptr;
602 DictionaryAttr attrs
;
603 SmallVector
<Value
, 8> operands
;
604 SmallVector
<Block
*, 2> successors
;
605 void *propertiesStorage
= nullptr;
608 /// Replacing an operation. Erasing an operation is treated as a special case
609 /// with "null" replacements. This rewrite is not immediately reflected in the
610 /// IR. An internal IR mapping is updated, but values are not replaced and the
611 /// original op is not erased until the rewrite is committed.
612 class ReplaceOperationRewrite
: public OperationRewrite
{
614 ReplaceOperationRewrite(ConversionPatternRewriterImpl
&rewriterImpl
,
615 Operation
*op
, const TypeConverter
*converter
)
616 : OperationRewrite(Kind::ReplaceOperation
, rewriterImpl
, op
),
617 converter(converter
) {}
619 static bool classof(const IRRewrite
*rewrite
) {
620 return rewrite
->getKind() == Kind::ReplaceOperation
;
623 void commit(RewriterBase
&rewriter
) override
;
625 void rollback() override
;
627 void cleanup(RewriterBase
&rewriter
) override
;
630 /// An optional type converter that can be used to materialize conversions
631 /// between the new and old values if necessary.
632 const TypeConverter
*converter
;
635 class CreateOperationRewrite
: public OperationRewrite
{
637 CreateOperationRewrite(ConversionPatternRewriterImpl
&rewriterImpl
,
639 : OperationRewrite(Kind::CreateOperation
, rewriterImpl
, op
) {}
641 static bool classof(const IRRewrite
*rewrite
) {
642 return rewrite
->getKind() == Kind::CreateOperation
;
645 void commit(RewriterBase
&rewriter
) override
{
646 // The operation was already created and inserted. Just inform the listener.
647 if (auto *listener
= rewriter
.getListener())
648 listener
->notifyOperationInserted(op
, /*previous=*/{});
651 void rollback() override
;
654 /// The type of materialization.
655 enum MaterializationKind
{
656 /// This materialization materializes a conversion for an illegal block
657 /// argument type, to the original one.
660 /// This materialization materializes a conversion from an illegal type to a
664 /// This materialization materializes a conversion from a legal type back to
669 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
670 /// op. Unresolved materializations are erased at the end of the dialect
672 class UnresolvedMaterializationRewrite
: public OperationRewrite
{
674 UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl
&rewriterImpl
,
675 UnrealizedConversionCastOp op
,
676 const TypeConverter
*converter
,
677 MaterializationKind kind
, Type originalType
);
679 static bool classof(const IRRewrite
*rewrite
) {
680 return rewrite
->getKind() == Kind::UnresolvedMaterialization
;
683 void rollback() override
;
685 UnrealizedConversionCastOp
getOperation() const {
686 return cast
<UnrealizedConversionCastOp
>(op
);
689 /// Return the type converter of this materialization (which may be null).
690 const TypeConverter
*getConverter() const {
691 return converterAndKind
.getPointer();
694 /// Return the kind of this materialization.
695 MaterializationKind
getMaterializationKind() const {
696 return converterAndKind
.getInt();
699 /// Return the original type of the SSA value.
700 Type
getOriginalType() const { return originalType
; }
703 /// The corresponding type converter to use when resolving this
704 /// materialization, and the kind of this materialization.
705 llvm::PointerIntPair
<const TypeConverter
*, 2, MaterializationKind
>
708 /// The original type of the SSA value. Only used for target
709 /// materializations.
714 /// Return "true" if there is an operation rewrite that matches the specified
715 /// rewrite type and operation among the given rewrites.
716 template <typename RewriteTy
, typename R
>
717 static bool hasRewrite(R
&&rewrites
, Operation
*op
) {
718 return any_of(std::forward
<R
>(rewrites
), [&](auto &rewrite
) {
719 auto *rewriteTy
= dyn_cast
<RewriteTy
>(rewrite
.get());
720 return rewriteTy
&& rewriteTy
->getOperation() == op
;
724 //===----------------------------------------------------------------------===//
725 // ConversionPatternRewriterImpl
726 //===----------------------------------------------------------------------===//
729 struct ConversionPatternRewriterImpl
: public RewriterBase::Listener
{
730 explicit ConversionPatternRewriterImpl(MLIRContext
*ctx
,
731 const ConversionConfig
&config
)
732 : context(ctx
), eraseRewriter(ctx
), config(config
) {}
734 //===--------------------------------------------------------------------===//
736 //===--------------------------------------------------------------------===//
738 /// Return the current state of the rewriter.
739 RewriterState
getCurrentState();
741 /// Apply all requested operation rewrites. This method is invoked when the
742 /// conversion process succeeds.
743 void applyRewrites();
745 /// Reset the state of the rewriter to a previously saved point.
746 void resetState(RewriterState state
);
748 /// Append a rewrite. Rewrites are committed upon success and rolled back upon
750 template <typename RewriteTy
, typename
... Args
>
751 void appendRewrite(Args
&&...args
) {
753 std::make_unique
<RewriteTy
>(*this, std::forward
<Args
>(args
)...));
756 /// Undo the rewrites (motions, splits) one by one in reverse order until
757 /// "numRewritesToKeep" rewrites remains.
758 void undoRewrites(unsigned numRewritesToKeep
= 0);
760 /// Remap the given values to those with potentially different types. Returns
761 /// success if the values could be remapped, failure otherwise. `valueDiagTag`
762 /// is the tag used when describing a value within a diagnostic, e.g.
764 LogicalResult
remapValues(StringRef valueDiagTag
,
765 std::optional
<Location
> inputLoc
,
766 PatternRewriter
&rewriter
, ValueRange values
,
767 SmallVectorImpl
<Value
> &remapped
);
769 /// Return "true" if the given operation is ignored, and does not need to be
771 bool isOpIgnored(Operation
*op
) const;
773 /// Return "true" if the given operation was replaced or erased.
774 bool wasOpReplaced(Operation
*op
) const;
776 //===--------------------------------------------------------------------===//
778 //===--------------------------------------------------------------------===//
780 /// Convert the types of block arguments within the given region.
782 convertRegionTypes(ConversionPatternRewriter
&rewriter
, Region
*region
,
783 const TypeConverter
&converter
,
784 TypeConverter::SignatureConversion
*entryConversion
);
786 /// Apply the given signature conversion on the given block. The new block
787 /// containing the updated signature is returned. If no conversions were
788 /// necessary, e.g. if the block has no arguments, `block` is returned.
789 /// `converter` is used to generate any necessary cast operations that
790 /// translate between the origin argument types and those specified in the
791 /// signature conversion.
792 Block
*applySignatureConversion(
793 ConversionPatternRewriter
&rewriter
, Block
*block
,
794 const TypeConverter
*converter
,
795 TypeConverter::SignatureConversion
&signatureConversion
);
797 //===--------------------------------------------------------------------===//
799 //===--------------------------------------------------------------------===//
801 /// Build an unresolved materialization operation given an output type and set
802 /// of input operands.
803 Value
buildUnresolvedMaterialization(MaterializationKind kind
,
804 OpBuilder::InsertPoint ip
, Location loc
,
805 ValueRange inputs
, Type outputType
,
807 const TypeConverter
*converter
);
809 /// Build an N:1 materialization for the given original value that was
810 /// replaced with the given replacement values.
812 /// This is a workaround around incomplete 1:N support in the dialect
813 /// conversion driver. The conversion mapping can store only 1:1 replacements
814 /// and the conversion patterns only support single Value replacements in the
815 /// adaptor, so N values must be converted back to a single value. This
816 /// function will be deleted when full 1:N support has been added.
818 /// This function inserts an argument materialization back to the original
819 /// type, followed by a target materialization to the legalized type (if
821 void insertNTo1Materialization(OpBuilder::InsertPoint ip
, Location loc
,
822 ValueRange replacements
, Value originalValue
,
823 const TypeConverter
*converter
);
825 /// Find a replacement value for the given SSA value in the conversion value
826 /// mapping. The replacement value must have the same type as the given SSA
827 /// value. If there is no replacement value with the correct type, find the
828 /// latest replacement value (regardless of the type) and build a source
830 Value
findOrBuildReplacementValue(Value value
,
831 const TypeConverter
*converter
);
833 //===--------------------------------------------------------------------===//
834 // Rewriter Notification Hooks
835 //===--------------------------------------------------------------------===//
837 //// Notifies that an op was inserted.
838 void notifyOperationInserted(Operation
*op
,
839 OpBuilder::InsertPoint previous
) override
;
841 /// Notifies that an op is about to be replaced with the given values.
842 void notifyOpReplaced(Operation
*op
, ArrayRef
<ReplacementValues
> newValues
);
844 /// Notifies that a block is about to be erased.
845 void notifyBlockIsBeingErased(Block
*block
);
847 /// Notifies that a block was inserted.
848 void notifyBlockInserted(Block
*block
, Region
*previous
,
849 Region::iterator previousIt
) override
;
851 /// Notifies that a block is being inlined into another block.
852 void notifyBlockBeingInlined(Block
*block
, Block
*srcBlock
,
853 Block::iterator before
);
855 /// Notifies that a pattern match failed for the given reason.
857 notifyMatchFailure(Location loc
,
858 function_ref
<void(Diagnostic
&)> reasonCallback
) override
;
860 //===--------------------------------------------------------------------===//
862 //===--------------------------------------------------------------------===//
864 /// A rewriter that keeps track of erased ops and blocks. It ensures that no
865 /// operation or block is erased multiple times. This rewriter assumes that
866 /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
867 struct SingleEraseRewriter
: public RewriterBase
, RewriterBase::Listener
{
869 SingleEraseRewriter(MLIRContext
*context
)
870 : RewriterBase(context
, /*listener=*/this) {}
872 /// Erase the given op (unless it was already erased).
873 void eraseOp(Operation
*op
) override
{
877 RewriterBase::eraseOp(op
);
880 /// Erase the given block (unless it was already erased).
881 void eraseBlock(Block
*block
) override
{
882 if (wasErased(block
))
884 assert(block
->empty() && "expected empty block");
885 block
->dropAllDefinedValueUses();
886 RewriterBase::eraseBlock(block
);
889 bool wasErased(void *ptr
) const { return erased
.contains(ptr
); }
891 void notifyOperationErased(Operation
*op
) override
{ erased
.insert(op
); }
893 void notifyBlockErased(Block
*block
) override
{ erased
.insert(block
); }
896 /// Pointers to all erased operations and blocks.
897 DenseSet
<void *> erased
;
  //===--------------------------------------------------------------------===//
  // State
  //===--------------------------------------------------------------------===//

  /// MLIR context.
  MLIRContext *context;

  /// A rewriter that keeps track of ops/block that were already erased and
  /// skips duplicate op/block erasures. This rewriter is used during the
  /// "cleanup" phase (see `applyRewrites`).
  SingleEraseRewriter eraseRewriter;

  /// Mapping between replaced values that differ in type. This happens when
  /// replacing a value with one of a different type.
  ConversionValueMapping mapping;

  /// Ordered log of IR rewrites: block/op creations, moves, inlinings, type
  /// conversions, replacements, in-place modifications and unresolved
  /// materializations. `applyRewrites` commits them in order; `undoRewrites`
  /// rolls them back in reverse order.
  SmallVector<std::unique_ptr<IRRewrite>> rewrites;

  /// A set of operations that should no longer be considered for legalization.
  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
  /// tracked separately.
  SetVector<Operation *> ignoredOps;

  /// A set of operations that were replaced/erased. Such ops are not erased
  /// immediately but only when the dialect conversion succeeds. In the mean
  /// time, they should no longer be considered for legalization and any attempt
  /// to modify/access them is invalid rewriter API usage.
  SetVector<Operation *> replacedOps;

  /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
  /// to the corresponding rewrite objects.
  DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
      unresolvedMaterializations;

  /// The current type converter, or nullptr if no type converter is currently
  /// active.
  const TypeConverter *currentTypeConverter = nullptr;

  /// A mapping of regions to type converters that should be used when
  /// converting the arguments of blocks within that region.
  DenseMap<Region *, const TypeConverter *> regionToConverter;

  /// Dialect conversion configuration.
  const ConversionConfig &config;

#ifndef NDEBUG
  /// A set of operations that have pending updates. This tracking isn't
  /// strictly necessary, and is thus only active during debug builds for extra
  /// verification.
  SmallPtrSet<Operation *, 1> pendingRootUpdates;

  /// A logger used to emit diagnostics during the conversion process.
  llvm::ScopedPrinter logger{llvm::dbgs()};
#endif
956 } // namespace detail
959 const ConversionConfig
&IRRewrite::getConfig() const {
960 return rewriterImpl
.config
;
963 void BlockTypeConversionRewrite::commit(RewriterBase
&rewriter
) {
964 // Inform the listener about all IR modifications that have already taken
965 // place: References to the original block have been replaced with the new
968 dyn_cast_or_null
<RewriterBase::Listener
>(rewriter
.getListener()))
969 for (Operation
*op
: block
->getUsers())
970 listener
->notifyOperationModified(op
);
973 void BlockTypeConversionRewrite::rollback() {
974 block
->replaceAllUsesWith(origBlock
);
977 void ReplaceBlockArgRewrite::commit(RewriterBase
&rewriter
) {
978 Value repl
= rewriterImpl
.findOrBuildReplacementValue(arg
, converter
);
982 if (isa
<BlockArgument
>(repl
)) {
983 rewriter
.replaceAllUsesWith(arg
, repl
);
987 // If the replacement value is an operation, we check to make sure that we
988 // don't replace uses that are within the parent operation of the
989 // replacement value.
990 Operation
*replOp
= cast
<OpResult
>(repl
).getOwner();
991 Block
*replBlock
= replOp
->getBlock();
992 rewriter
.replaceUsesWithIf(arg
, repl
, [&](OpOperand
&operand
) {
993 Operation
*user
= operand
.getOwner();
994 return user
->getBlock() != replBlock
|| replOp
->isBeforeInBlock(user
);
998 void ReplaceBlockArgRewrite::rollback() { rewriterImpl
.mapping
.erase(arg
); }
1000 void ReplaceOperationRewrite::commit(RewriterBase
&rewriter
) {
1002 dyn_cast_or_null
<RewriterBase::Listener
>(rewriter
.getListener());
1004 // Compute replacement values.
1005 SmallVector
<Value
> replacements
=
1006 llvm::map_to_vector(op
->getResults(), [&](OpResult result
) {
1007 return rewriterImpl
.findOrBuildReplacementValue(result
, converter
);
1010 // Notify the listener that the operation is about to be replaced.
1012 listener
->notifyOperationReplaced(op
, replacements
);
1014 // Replace all uses with the new values.
1015 for (auto [result
, newValue
] :
1016 llvm::zip_equal(op
->getResults(), replacements
))
1018 rewriter
.replaceAllUsesWith(result
, newValue
);
1020 // The original op will be erased, so remove it from the set of unlegalized
1022 if (getConfig().unlegalizedOps
)
1023 getConfig().unlegalizedOps
->erase(op
);
1025 // Notify the listener that the operation (and its nested operations) was
1028 op
->walk
<WalkOrder::PostOrder
>(
1029 [&](Operation
*op
) { listener
->notifyOperationErased(op
); });
1032 // Do not erase the operation yet. It may still be referenced in `mapping`.
1033 // Just unlink it for now and erase it during cleanup.
1034 op
->getBlock()->getOperations().remove(op
);
1037 void ReplaceOperationRewrite::rollback() {
1038 for (auto result
: op
->getResults())
1039 rewriterImpl
.mapping
.erase(result
);
1042 void ReplaceOperationRewrite::cleanup(RewriterBase
&rewriter
) {
1043 rewriter
.eraseOp(op
);
1046 void CreateOperationRewrite::rollback() {
1047 for (Region
®ion
: op
->getRegions()) {
1048 while (!region
.getBlocks().empty())
1049 region
.getBlocks().remove(region
.getBlocks().begin());
1055 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
1056 ConversionPatternRewriterImpl
&rewriterImpl
, UnrealizedConversionCastOp op
,
1057 const TypeConverter
*converter
, MaterializationKind kind
, Type originalType
)
1058 : OperationRewrite(Kind::UnresolvedMaterialization
, rewriterImpl
, op
),
1059 converterAndKind(converter
, kind
), originalType(originalType
) {
1060 assert((!originalType
|| kind
== MaterializationKind::Target
) &&
1061 "original type is valid only for target materializations");
1062 rewriterImpl
.unresolvedMaterializations
[op
] = this;
1065 void UnresolvedMaterializationRewrite::rollback() {
1066 if (getMaterializationKind() == MaterializationKind::Target
) {
1067 for (Value input
: op
->getOperands())
1068 rewriterImpl
.mapping
.erase(input
);
1070 rewriterImpl
.unresolvedMaterializations
.erase(getOperation());
1074 void ConversionPatternRewriterImpl::applyRewrites() {
1075 // Commit all rewrites.
1076 IRRewriter
rewriter(context
, config
.listener
);
1077 // Note: New rewrites may be added during the "commit" phase and the
1078 // `rewrites` vector may reallocate.
1079 for (size_t i
= 0; i
< rewrites
.size(); ++i
)
1080 rewrites
[i
]->commit(rewriter
);
1082 // Clean up all rewrites.
1083 for (auto &rewrite
: rewrites
)
1084 rewrite
->cleanup(eraseRewriter
);
1087 //===----------------------------------------------------------------------===//
1090 RewriterState
ConversionPatternRewriterImpl::getCurrentState() {
1091 return RewriterState(rewrites
.size(), ignoredOps
.size(), replacedOps
.size());
1094 void ConversionPatternRewriterImpl::resetState(RewriterState state
) {
1095 // Undo any rewrites.
1096 undoRewrites(state
.numRewrites
);
1098 // Pop all of the recorded ignored operations that are no longer valid.
1099 while (ignoredOps
.size() != state
.numIgnoredOperations
)
1100 ignoredOps
.pop_back();
1102 while (replacedOps
.size() != state
.numReplacedOps
)
1103 replacedOps
.pop_back();
1106 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep
) {
1107 for (auto &rewrite
:
1108 llvm::reverse(llvm::drop_begin(rewrites
, numRewritesToKeep
)))
1109 rewrite
->rollback();
1110 rewrites
.resize(numRewritesToKeep
);
/// Appends to `remapped` the current (converted) counterpart of each value in
/// `values`.
///
/// - Without an active type converter, the most recently mapped value is used.
/// - With a converter, each value's type is converted. A failed conversion
///   reports a match failure (tagged with `valueDiagTag` and the value index)
///   and returns failure.
/// - 1->0 and 1->N type conversions fall back to the most recently mapped
///   value.
/// - For 1->1 conversions, a mapped value of the legal type is looked up. If
///   none exists, a target materialization is built and recorded in `mapping`.
///
/// `inputLoc`, if set, overrides the per-value location used for diagnostics
/// and materializations.
LogicalResult ConversionPatternRewriterImpl::remapValues(
    StringRef valueDiagTag, std::optional<Location> inputLoc,
    PatternRewriter &rewriter, ValueRange values,
    SmallVectorImpl<Value> &remapped) {
  remapped.reserve(llvm::size(values));

  for (const auto &it : llvm::enumerate(values)) {
    Value operand = it.value();
    Type origType = operand.getType();
    Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();

    if (!currentTypeConverter) {
      // The current pattern does not have a type converter. I.e., it does not
      // distinguish between legal and illegal types. For each operand, simply
      // pass through the most recently mapped value.
      remapped.push_back(mapping.lookupOrDefault(operand));
      continue;
    }

    // If there is no legal conversion, fail to match this pattern.
    SmallVector<Type, 1> legalTypes;
    if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
      notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
        diag << "unable to convert type for " << valueDiagTag << " #"
             << it.index() << ", type was " << origType;
      });
      return failure();
    }

    if (legalTypes.size() != 1) {
      // TODO: Parts of the dialect conversion infrastructure do not support
      // 1->N type conversions yet. Therefore, if a type is converted to 0 or
      // multiple types, the only thing that we can do for now is passing
      // through the most recently mapped value. Fixing this requires
      // improvements to the `ConversionValueMapping` (to be able to store 1:N
      // mappings) and to the `ConversionPattern` adaptor handling (to be able
      // to pass multiple remapped values for a single operand to the adaptor).
      remapped.push_back(mapping.lookupOrDefault(operand));
      continue;
    }

    // Handle 1->1 type conversions.
    Type desiredType = legalTypes.front();
    // Try to find a mapped value with the desired type. (Or the operand itself
    // if the value is not mapped at all.)
    Value newOperand = mapping.lookupOrDefault(operand, desiredType);
    if (newOperand.getType() != desiredType) {
      // If the looked up value's type does not have the desired type, it means
      // that the value was replaced with a value of different type and no
      // source materialization was created yet.
      Value castValue = buildUnresolvedMaterialization(
          MaterializationKind::Target, computeInsertPoint(newOperand),
          operandLoc,
          /*inputs=*/newOperand, /*outputType=*/desiredType,
          /*originalType=*/origType, currentTypeConverter);
      // Record the cast so later lookups of `newOperand` find it.
      mapping.map(newOperand, castValue);
      newOperand = castValue;
    }
    remapped.push_back(newOperand);
  }
  return success();
}
1176 bool ConversionPatternRewriterImpl::isOpIgnored(Operation
*op
) const {
1177 // Check to see if this operation is ignored or was replaced.
1178 return replacedOps
.count(op
) || ignoredOps
.count(op
);
1181 bool ConversionPatternRewriterImpl::wasOpReplaced(Operation
*op
) const {
1182 // Check to see if this operation was replaced.
1183 return replacedOps
.count(op
);
1186 //===----------------------------------------------------------------------===//
1189 FailureOr
<Block
*> ConversionPatternRewriterImpl::convertRegionTypes(
1190 ConversionPatternRewriter
&rewriter
, Region
*region
,
1191 const TypeConverter
&converter
,
1192 TypeConverter::SignatureConversion
*entryConversion
) {
1193 regionToConverter
[region
] = &converter
;
1194 if (region
->empty())
1197 // Convert the arguments of each non-entry block within the region.
1199 llvm::make_early_inc_range(llvm::drop_begin(*region
, 1))) {
1200 // Compute the signature for the block with the provided converter.
1201 std::optional
<TypeConverter::SignatureConversion
> conversion
=
1202 converter
.convertBlockSignature(&block
);
1205 // Convert the block with the computed signature.
1206 applySignatureConversion(rewriter
, &block
, &converter
, *conversion
);
1209 // Convert the entry block. If an entry signature conversion was provided,
1210 // use that one. Otherwise, compute the signature with the type converter.
1211 if (entryConversion
)
1212 return applySignatureConversion(rewriter
, ®ion
->front(), &converter
,
1214 std::optional
<TypeConverter::SignatureConversion
> conversion
=
1215 converter
.convertBlockSignature(®ion
->front());
1218 return applySignatureConversion(rewriter
, ®ion
->front(), &converter
,
/// Converts the signature of `block` according to `signatureConversion`.
///
/// If the converted argument types equal the current ones, `block` is returned
/// unchanged. Otherwise:
/// - a new block with the converted argument types is inserted right after
///   `block`;
/// - all ops are moved into the new block;
/// - all uses of `block` are redirected to the new block;
/// - every original block argument is mapped to a replacement value;
/// - `block` is erased (unlinked now, destroyed during cleanup).
/// Returns the new block.
Block *ConversionPatternRewriterImpl::applySignatureConversion(
    ConversionPatternRewriter &rewriter, Block *block,
    const TypeConverter *converter,
    TypeConverter::SignatureConversion &signatureConversion) {
  OpBuilder::InsertionGuard g(rewriter);

  // If no arguments are being changed or added, there is nothing to do.
  unsigned origArgCount = block->getNumArguments();
  auto convertedTypes = signatureConversion.getConvertedTypes();
  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
    return block;

  // Compute the locations of all block arguments in the new block. New
  // arguments that stem from an original argument inherit its location; all
  // other new arguments keep an unknown location.
  SmallVector<Location> newLocs(convertedTypes.size(),
                                rewriter.getUnknownLoc());
  for (unsigned i = 0; i < origArgCount; ++i) {
    auto inputMap = signatureConversion.getInputMapping(i);
    if (!inputMap || inputMap->replacementValue)
      continue;
    Location origLoc = block->getArgument(i).getLoc();
    for (unsigned j = 0; j < inputMap->size; ++j)
      newLocs[inputMap->inputNo + j] = origLoc;
  }

  // Insert a new block with the converted block argument types and move all ops
  // from the old block to the new block.
  Block *newBlock =
      rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
                           convertedTypes, newLocs);

  // If a listener is attached to the dialect conversion, ops cannot be moved
  // to the destination block in bulk ("fast path"). This is because at the time
  // the notifications are sent, it is unknown which ops were moved. Instead,
  // ops should be moved one-by-one ("slow path"), so that a separate
  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
  // a bit more efficient, so we try to do that when possible.
  bool fastPath = !config.listener;
  if (fastPath) {
    appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
    newBlock->getOperations().splice(newBlock->end(), block->getOperations());
  } else {
    while (!block->empty())
      rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
  }

  // Replace all uses of the old block with the new block.
  block->replaceAllUsesWith(newBlock);

  // Map each original argument to its replacement. Every case records a
  // ReplaceBlockArgRewrite so that uses are rewired at commit time.
  for (unsigned i = 0; i != origArgCount; ++i) {
    BlockArgument origArg = block->getArgument(i);
    Type origArgType = origArg.getType();

    std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
        signatureConversion.getInputMapping(i);
    if (!inputMap) {
      // This block argument was dropped and no replacement value was provided.
      // Materialize a replacement value "out of thin air".
      Value repl = buildUnresolvedMaterialization(
          MaterializationKind::Source,
          OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
          /*inputs=*/ValueRange(),
          /*outputType=*/origArgType, /*originalType=*/Type(), converter);
      mapping.map(origArg, repl);
      appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
      continue;
    }

    if (Value repl = inputMap->replacementValue) {
      // This block argument was dropped and a replacement value was provided.
      assert(inputMap->size == 0 &&
             "invalid to provide a replacement value when the argument isn't "
             "dropped");
      mapping.map(origArg, repl);
      appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
      continue;
    }

    // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
    // dialect conversion. Therefore, we need an argument materialization to
    // turn the replacement block arguments into a single SSA value that can be
    // used as a replacement.
    auto replArgs =
        newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
    insertNTo1Materialization(
        OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
        /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
    appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
  }

  appendRewrite<BlockTypeConversionRewrite>(newBlock, block);

  // Erase the old block. (It is just unlinked for now and will be erased during
  // cleanup.)
  rewriter.eraseBlock(block);

  return newBlock;
}
1320 //===----------------------------------------------------------------------===//
1322 //===----------------------------------------------------------------------===//
1324 /// Build an unresolved materialization operation given an output type and set
1325 /// of input operands.
1326 Value
ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1327 MaterializationKind kind
, OpBuilder::InsertPoint ip
, Location loc
,
1328 ValueRange inputs
, Type outputType
, Type originalType
,
1329 const TypeConverter
*converter
) {
1330 assert((!originalType
|| kind
== MaterializationKind::Target
) &&
1331 "original type is valid only for target materializations");
1333 // Avoid materializing an unnecessary cast.
1334 if (inputs
.size() == 1 && inputs
.front().getType() == outputType
)
1335 return inputs
.front();
1337 // Create an unresolved materialization. We use a new OpBuilder to avoid
1338 // tracking the materialization like we do for other operations.
1339 OpBuilder
builder(outputType
.getContext());
1340 builder
.setInsertionPoint(ip
.getBlock(), ip
.getPoint());
1342 builder
.create
<UnrealizedConversionCastOp
>(loc
, outputType
, inputs
);
1343 appendRewrite
<UnresolvedMaterializationRewrite
>(convertOp
, converter
, kind
,
1345 return convertOp
.getResult(0);
1348 void ConversionPatternRewriterImpl::insertNTo1Materialization(
1349 OpBuilder::InsertPoint ip
, Location loc
, ValueRange replacements
,
1350 Value originalValue
, const TypeConverter
*converter
) {
1351 // Insert argument materialization back to the original type.
1352 Type originalType
= originalValue
.getType();
1354 buildUnresolvedMaterialization(MaterializationKind::Argument
, ip
, loc
,
1355 /*inputs=*/replacements
, originalType
,
1356 /*originalType=*/Type(), converter
);
1357 mapping
.map(originalValue
, argMat
);
1359 // Insert target materialization to the legalized type.
1360 Type legalOutputType
;
1362 legalOutputType
= converter
->convertType(originalType
);
1363 } else if (replacements
.size() == 1) {
1364 // When there is no type converter, assume that the replacement value
1365 // types are legal. This is reasonable to assume because they were
1366 // specified by the user.
1367 // FIXME: This won't work for 1->N conversions because multiple output
1368 // types are not supported in parts of the dialect conversion. In such a
1369 // case, we currently use the original value type.
1370 legalOutputType
= replacements
[0].getType();
1372 if (legalOutputType
&& legalOutputType
!= originalType
) {
1373 Value targetMat
= buildUnresolvedMaterialization(
1374 MaterializationKind::Target
, computeInsertPoint(argMat
), loc
,
1375 /*inputs=*/argMat
, /*outputType=*/legalOutputType
,
1376 /*originalType=*/originalType
, converter
);
1377 mapping
.map(argMat
, targetMat
);
1381 Value
ConversionPatternRewriterImpl::findOrBuildReplacementValue(
1382 Value value
, const TypeConverter
*converter
) {
1383 // Find a replacement value with the same type.
1384 Value repl
= mapping
.lookupOrNull(value
, value
.getType());
1388 // Check if the value is dead. No replacement value is needed in that case.
1389 // This is an approximate check that may have false negatives but does not
1390 // require computing and traversing an inverse mapping. (We may end up
1391 // building source materializations that are never used and that fold away.)
1392 if (llvm::all_of(value
.getUsers(),
1393 [&](Operation
*op
) { return replacedOps
.contains(op
); }) &&
1394 !mapping
.isMappedTo(value
))
1397 // No replacement value was found. Get the latest replacement value
1398 // (regardless of the type) and build a source materialization to the
1400 repl
= mapping
.lookupOrNull(value
);
1402 // No replacement value is registered in the mapping. This means that the
1403 // value is dropped and no longer needed. (If the value were still needed,
1404 // a source materialization producing a replacement value "out of thin air"
1405 // would have already been created during `replaceOp` or
1406 // `applySignatureConversion`.)
1409 Value castValue
= buildUnresolvedMaterialization(
1410 MaterializationKind::Source
, computeInsertPoint(repl
), value
.getLoc(),
1411 /*inputs=*/repl
, /*outputType=*/value
.getType(),
1412 /*originalType=*/Type(), converter
);
1416 //===----------------------------------------------------------------------===//
1417 // Rewriter Notification Hooks
1419 void ConversionPatternRewriterImpl::notifyOperationInserted(
1420 Operation
*op
, OpBuilder::InsertPoint previous
) {
1422 logger
.startLine() << "** Insert : '" << op
->getName() << "'(" << op
1425 assert(!wasOpReplaced(op
->getParentOp()) &&
1426 "attempting to insert into a block within a replaced/erased op");
1428 if (!previous
.isSet()) {
1429 // This is a newly created op.
1430 appendRewrite
<CreateOperationRewrite
>(op
);
1433 Operation
*prevOp
= previous
.getPoint() == previous
.getBlock()->end()
1435 : &*previous
.getPoint();
1436 appendRewrite
<MoveOperationRewrite
>(op
, previous
.getBlock(), prevOp
);
1439 void ConversionPatternRewriterImpl::notifyOpReplaced(
1440 Operation
*op
, ArrayRef
<ReplacementValues
> newValues
) {
1441 assert(newValues
.size() == op
->getNumResults());
1442 assert(!ignoredOps
.contains(op
) && "operation was already replaced");
1444 // Check if replaced op is an unresolved materialization, i.e., an
1445 // unrealized_conversion_cast op that was created by the conversion driver.
1446 bool isUnresolvedMaterialization
= false;
1447 if (auto castOp
= dyn_cast
<UnrealizedConversionCastOp
>(op
))
1448 if (unresolvedMaterializations
.contains(castOp
))
1449 isUnresolvedMaterialization
= true;
1451 // Create mappings for each of the new result values.
1452 for (auto [n
, result
] : llvm::zip_equal(newValues
, op
->getResults())) {
1453 ReplacementValues repl
= n
;
1455 // This result was dropped and no replacement value was provided.
1456 if (isUnresolvedMaterialization
) {
1457 // Do not create another materializations if we are erasing a
1462 // Materialize a replacement value "out of thin air".
1463 Value sourceMat
= buildUnresolvedMaterialization(
1464 MaterializationKind::Source
, computeInsertPoint(result
),
1465 result
.getLoc(), /*inputs=*/ValueRange(),
1466 /*outputType=*/result
.getType(), /*originalType=*/Type(),
1467 currentTypeConverter
);
1468 repl
.push_back(sourceMat
);
1470 // Make sure that the user does not mess with unresolved materializations
1471 // that were inserted by the conversion driver. We keep track of these
1472 // ops in internal data structures. Erasing them must be allowed because
1473 // this can happen when the user is erasing an entire block (including
1474 // its body). But replacing them with another value should be forbidden
1475 // to avoid problems with the `mapping`.
1476 assert(!isUnresolvedMaterialization
&&
1477 "attempting to replace an unresolved materialization");
1480 // Remap result to replacement value.
1484 if (repl
.size() == 1) {
1485 // Single replacement value: replace directly.
1486 mapping
.map(result
, repl
.front());
1488 // Multiple replacement values: insert N:1 materialization.
1489 insertNTo1Materialization(computeInsertPoint(result
), result
.getLoc(),
1490 /*replacements=*/repl
, /*outputValue=*/result
,
1491 currentTypeConverter
);
1495 appendRewrite
<ReplaceOperationRewrite
>(op
, currentTypeConverter
);
1496 // Mark this operation and all nested ops as replaced.
1497 op
->walk([&](Operation
*op
) { replacedOps
.insert(op
); });
1500 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block
*block
) {
1501 appendRewrite
<EraseBlockRewrite
>(block
);
1504 void ConversionPatternRewriterImpl::notifyBlockInserted(
1505 Block
*block
, Region
*previous
, Region::iterator previousIt
) {
1506 assert(!wasOpReplaced(block
->getParentOp()) &&
1507 "attempting to insert into a region within a replaced/erased op");
1510 Operation
*parent
= block
->getParentOp();
1512 logger
.startLine() << "** Insert Block into : '" << parent
->getName()
1513 << "'(" << parent
<< ")\n";
1516 << "** Insert Block into detached Region (nullptr parent op)'";
1521 // This is a newly created block.
1522 appendRewrite
<CreateBlockRewrite
>(block
);
1525 Block
*prevBlock
= previousIt
== previous
->end() ? nullptr : &*previousIt
;
1526 appendRewrite
<MoveBlockRewrite
>(block
, previous
, prevBlock
);
1529 void ConversionPatternRewriterImpl::notifyBlockBeingInlined(
1530 Block
*block
, Block
*srcBlock
, Block::iterator before
) {
1531 appendRewrite
<InlineBlockRewrite
>(block
, srcBlock
, before
);
1534 void ConversionPatternRewriterImpl::notifyMatchFailure(
1535 Location loc
, function_ref
<void(Diagnostic
&)> reasonCallback
) {
1537 Diagnostic
diag(loc
, DiagnosticSeverity::Remark
);
1538 reasonCallback(diag
);
1539 logger
.startLine() << "** Failure : " << diag
.str() << "\n";
1540 if (config
.notifyCallback
)
1541 config
.notifyCallback(diag
);
1545 //===----------------------------------------------------------------------===//
1546 // ConversionPatternRewriter
1547 //===----------------------------------------------------------------------===//
1549 ConversionPatternRewriter::ConversionPatternRewriter(
1550 MLIRContext
*ctx
, const ConversionConfig
&config
)
1551 : PatternRewriter(ctx
),
1552 impl(new detail::ConversionPatternRewriterImpl(ctx
, config
)) {
1553 setListener(impl
.get());
// Defaulted out of line. NOTE(review): presumably so that `impl` is destroyed
// where `detail::ConversionPatternRewriterImpl` is a complete type; confirm.
ConversionPatternRewriter::~ConversionPatternRewriter() = default;
1558 void ConversionPatternRewriter::replaceOp(Operation
*op
, Operation
*newOp
) {
1559 assert(op
&& newOp
&& "expected non-null op");
1560 replaceOp(op
, newOp
->getResults());
1563 void ConversionPatternRewriter::replaceOp(Operation
*op
, ValueRange newValues
) {
1564 assert(op
->getNumResults() == newValues
.size() &&
1565 "incorrect # of replacement values");
1567 impl
->logger
.startLine()
1568 << "** Replace : '" << op
->getName() << "'(" << op
<< ")\n";
1570 SmallVector
<ReplacementValues
> newVals(newValues
.size());
1571 for (auto [index
, val
] : llvm::enumerate(newValues
))
1573 newVals
[index
].push_back(val
);
1574 impl
->notifyOpReplaced(op
, newVals
);
1577 void ConversionPatternRewriter::replaceOpWithMultiple(
1578 Operation
*op
, ArrayRef
<ValueRange
> newValues
) {
1579 assert(op
->getNumResults() == newValues
.size() &&
1580 "incorrect # of replacement values");
1582 impl
->logger
.startLine()
1583 << "** Replace : '" << op
->getName() << "'(" << op
<< ")\n";
1585 SmallVector
<ReplacementValues
> newVals(newValues
.size(), {});
1586 for (auto [index
, val
] : llvm::enumerate(newValues
))
1587 llvm::append_range(newVals
[index
], val
);
1588 impl
->notifyOpReplaced(op
, newVals
);
1591 void ConversionPatternRewriter::eraseOp(Operation
*op
) {
1593 impl
->logger
.startLine()
1594 << "** Erase : '" << op
->getName() << "'(" << op
<< ")\n";
1596 SmallVector
<ReplacementValues
> nullRepls(op
->getNumResults(), {});
1597 impl
->notifyOpReplaced(op
, nullRepls
);
1600 void ConversionPatternRewriter::eraseBlock(Block
*block
) {
1601 assert(!impl
->wasOpReplaced(block
->getParentOp()) &&
1602 "attempting to erase a block within a replaced/erased op");
1604 // Mark all ops for erasure.
1605 for (Operation
&op
: *block
)
1608 // Unlink the block from its parent region. The block is kept in the rewrite
1609 // object and will be actually destroyed when rewrites are applied. This
1610 // allows us to keep the operations in the block live and undo the removal by
1611 // re-inserting the block.
1612 impl
->notifyBlockIsBeingErased(block
);
1613 block
->getParent()->getBlocks().remove(block
);
1616 Block
*ConversionPatternRewriter::applySignatureConversion(
1617 Block
*block
, TypeConverter::SignatureConversion
&conversion
,
1618 const TypeConverter
*converter
) {
1619 assert(!impl
->wasOpReplaced(block
->getParentOp()) &&
1620 "attempting to apply a signature conversion to a block within a "
1621 "replaced/erased op");
1622 return impl
->applySignatureConversion(*this, block
, converter
, conversion
);
1625 FailureOr
<Block
*> ConversionPatternRewriter::convertRegionTypes(
1626 Region
*region
, const TypeConverter
&converter
,
1627 TypeConverter::SignatureConversion
*entryConversion
) {
1628 assert(!impl
->wasOpReplaced(region
->getParentOp()) &&
1629 "attempting to apply a signature conversion to a block within a "
1630 "replaced/erased op");
1631 return impl
->convertRegionTypes(*this, region
, converter
, entryConversion
);
1634 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from
,
1637 Operation
*parentOp
= from
.getOwner()->getParentOp();
1638 impl
->logger
.startLine() << "** Replace Argument : '" << from
1639 << "'(in region of '" << parentOp
->getName()
1640 << "'(" << from
.getOwner()->getParentOp() << ")\n";
1642 impl
->appendRewrite
<ReplaceBlockArgRewrite
>(from
.getOwner(), from
,
1643 impl
->currentTypeConverter
);
1644 impl
->mapping
.map(impl
->mapping
.lookupOrDefault(from
), to
);
1647 Value
ConversionPatternRewriter::getRemappedValue(Value key
) {
1648 SmallVector
<Value
> remappedValues
;
1649 if (failed(impl
->remapValues("value", /*inputLoc=*/std::nullopt
, *this, key
,
1652 return remappedValues
.front();
1656 ConversionPatternRewriter::getRemappedValues(ValueRange keys
,
1657 SmallVectorImpl
<Value
> &results
) {
1660 return impl
->remapValues("value", /*inputLoc=*/std::nullopt
, *this, keys
,
1664 void ConversionPatternRewriter::inlineBlockBefore(Block
*source
, Block
*dest
,
1665 Block::iterator before
,
1666 ValueRange argValues
) {
1668 assert(argValues
.size() == source
->getNumArguments() &&
1669 "incorrect # of argument replacement values");
1670 assert(!impl
->wasOpReplaced(source
->getParentOp()) &&
1671 "attempting to inline a block from a replaced/erased op");
1672 assert(!impl
->wasOpReplaced(dest
->getParentOp()) &&
1673 "attempting to inline a block into a replaced/erased op");
1674 auto opIgnored
= [&](Operation
*op
) { return impl
->isOpIgnored(op
); };
1675 // The source block will be deleted, so it should not have any users (i.e.,
1676 // there should be no predecessors).
1677 assert(llvm::all_of(source
->getUsers(), opIgnored
) &&
1678 "expected 'source' to have no predecessors");
1681 // If a listener is attached to the dialect conversion, ops cannot be moved
1682 // to the destination block in bulk ("fast path"). This is because at the time
1683 // the notifications are sent, it is unknown which ops were moved. Instead,
1684 // ops should be moved one-by-one ("slow path"), so that a separate
1685 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1686 // a bit more efficient, so we try to do that when possible.
1687 bool fastPath
= !impl
->config
.listener
;
1690 impl
->notifyBlockBeingInlined(dest
, source
, before
);
1692 // Replace all uses of block arguments.
1693 for (auto it
: llvm::zip(source
->getArguments(), argValues
))
1694 replaceUsesOfBlockArgument(std::get
<0>(it
), std::get
<1>(it
));
1697 // Move all ops at once.
1698 dest
->getOperations().splice(before
, source
->getOperations());
1701 while (!source
->empty())
1702 moveOpBefore(&source
->front(), dest
, before
);
1705 // Erase the source block.
1709 void ConversionPatternRewriter::startOpModification(Operation
*op
) {
1710 assert(!impl
->wasOpReplaced(op
) &&
1711 "attempting to modify a replaced/erased op");
1713 impl
->pendingRootUpdates
.insert(op
);
1715 impl
->appendRewrite
<ModifyOperationRewrite
>(op
);
1718 void ConversionPatternRewriter::finalizeOpModification(Operation
*op
) {
1719 assert(!impl
->wasOpReplaced(op
) &&
1720 "attempting to modify a replaced/erased op");
1721 PatternRewriter::finalizeOpModification(op
);
1722 // There is nothing to do here, we only need to track the operation at the
1723 // start of the update.
1725 assert(impl
->pendingRootUpdates
.erase(op
) &&
1726 "operation did not have a pending in-place update");
1730 void ConversionPatternRewriter::cancelOpModification(Operation
*op
) {
1732 assert(impl
->pendingRootUpdates
.erase(op
) &&
1733 "operation did not have a pending in-place update");
1735 // Erase the last update for this operation.
1736 auto it
= llvm::find_if(
1737 llvm::reverse(impl
->rewrites
), [&](std::unique_ptr
<IRRewrite
> &rewrite
) {
1738 auto *modifyRewrite
= dyn_cast
<ModifyOperationRewrite
>(rewrite
.get());
1739 return modifyRewrite
&& modifyRewrite
->getOperation() == op
;
1741 assert(it
!= impl
->rewrites
.rend() && "no root update started on op");
1743 int updateIdx
= std::prev(impl
->rewrites
.rend()) - it
;
1744 impl
->rewrites
.erase(impl
->rewrites
.begin() + updateIdx
);
1747 detail::ConversionPatternRewriterImpl
&ConversionPatternRewriter::getImpl() {
1751 //===----------------------------------------------------------------------===//
1752 // ConversionPattern
1753 //===----------------------------------------------------------------------===//
1756 ConversionPattern::matchAndRewrite(Operation
*op
,
1757 PatternRewriter
&rewriter
) const {
1758 auto &dialectRewriter
= static_cast<ConversionPatternRewriter
&>(rewriter
);
1759 auto &rewriterImpl
= dialectRewriter
.getImpl();
1761 // Track the current conversion pattern type converter in the rewriter.
1762 llvm::SaveAndRestore
currentConverterGuard(rewriterImpl
.currentTypeConverter
,
1763 getTypeConverter());
1765 // Remap the operands of the operation.
1766 SmallVector
<Value
, 4> operands
;
1767 if (failed(rewriterImpl
.remapValues("operand", op
->getLoc(), rewriter
,
1768 op
->getOperands(), operands
))) {
1771 return matchAndRewrite(op
, operands
, dialectRewriter
);
1774 //===----------------------------------------------------------------------===//
1775 // OperationLegalizer
1776 //===----------------------------------------------------------------------===//
1779 /// A set of rewrite patterns that can be used to legalize a given operation.
1780 using LegalizationPatterns
= SmallVector
<const Pattern
*, 1>;
1782 /// This class defines a recursive operation legalizer.
1783 class OperationLegalizer
{
1785 using LegalizationAction
= ConversionTarget::LegalizationAction
;
1787 OperationLegalizer(const ConversionTarget
&targetInfo
,
1788 const FrozenRewritePatternSet
&patterns
,
1789 const ConversionConfig
&config
);
1791 /// Returns true if the given operation is known to be illegal on the target.
1792 bool isIllegal(Operation
*op
) const;
1794 /// Attempt to legalize the given operation. Returns success if the operation
1795 /// was legalized, failure otherwise.
1796 LogicalResult
legalize(Operation
*op
, ConversionPatternRewriter
&rewriter
);
1798 /// Returns the conversion target in use by the legalizer.
1799 const ConversionTarget
&getTarget() { return target
; }
1802 /// Attempt to legalize the given operation by folding it.
1803 LogicalResult
legalizeWithFold(Operation
*op
,
1804 ConversionPatternRewriter
&rewriter
);
1806 /// Attempt to legalize the given operation by applying a pattern. Returns
1807 /// success if the operation was legalized, failure otherwise.
1808 LogicalResult
legalizeWithPattern(Operation
*op
,
1809 ConversionPatternRewriter
&rewriter
);
1811 /// Return true if the given pattern may be applied to the given operation,
1812 /// false otherwise.
1813 bool canApplyPattern(Operation
*op
, const Pattern
&pattern
,
1814 ConversionPatternRewriter
&rewriter
);
1816 /// Legalize the resultant IR after successfully applying the given pattern.
1817 LogicalResult
legalizePatternResult(Operation
*op
, const Pattern
&pattern
,
1818 ConversionPatternRewriter
&rewriter
,
1819 RewriterState
&curState
);
1821 /// Legalizes the actions registered during the execution of a pattern.
1823 legalizePatternBlockRewrites(Operation
*op
,
1824 ConversionPatternRewriter
&rewriter
,
1825 ConversionPatternRewriterImpl
&impl
,
1826 RewriterState
&state
, RewriterState
&newState
);
1827 LogicalResult
legalizePatternCreatedOperations(
1828 ConversionPatternRewriter
&rewriter
, ConversionPatternRewriterImpl
&impl
,
1829 RewriterState
&state
, RewriterState
&newState
);
1830 LogicalResult
legalizePatternRootUpdates(ConversionPatternRewriter
&rewriter
,
1831 ConversionPatternRewriterImpl
&impl
,
1832 RewriterState
&state
,
1833 RewriterState
&newState
);
1835 //===--------------------------------------------------------------------===//
1837 //===--------------------------------------------------------------------===//
1839 /// Build an optimistic legalization graph given the provided patterns. This
1840 /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1841 /// patterns for operations that are not directly legal, but may be
1842 /// transitively legal for the current target given the provided patterns.
1843 void buildLegalizationGraph(
1844 LegalizationPatterns
&anyOpLegalizerPatterns
,
1845 DenseMap
<OperationName
, LegalizationPatterns
> &legalizerPatterns
);
1847 /// Compute the benefit of each node within the computed legalization graph.
1848 /// This orders the patterns within 'legalizerPatterns' based upon two
1850 /// 1) Prefer patterns that have the lowest legalization depth, i.e.
1851 /// represent the more direct mapping to the target.
1852 /// 2) When comparing patterns with the same legalization depth, prefer the
1853 /// pattern with the highest PatternBenefit. This allows for users to
1854 /// prefer specific legalizations over others.
1855 void computeLegalizationGraphBenefit(
1856 LegalizationPatterns
&anyOpLegalizerPatterns
,
1857 DenseMap
<OperationName
, LegalizationPatterns
> &legalizerPatterns
);
1859 /// Compute the legalization depth when legalizing an operation of the given
1861 unsigned computeOpLegalizationDepth(
1862 OperationName op
, DenseMap
<OperationName
, unsigned> &minOpPatternDepth
,
1863 DenseMap
<OperationName
, LegalizationPatterns
> &legalizerPatterns
);
1865 /// Apply the conversion cost model to the given set of patterns, and return
1866 /// the smallest legalization depth of any of the patterns. See
1867 /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1868 unsigned applyCostModelToPatterns(
1869 LegalizationPatterns
&patterns
,
1870 DenseMap
<OperationName
, unsigned> &minOpPatternDepth
,
1871 DenseMap
<OperationName
, LegalizationPatterns
> &legalizerPatterns
);
1873 /// The current set of patterns that have been applied.
1874 SmallPtrSet
<const Pattern
*, 8> appliedPatterns
;
1876 /// The legalization information provided by the target.
1877 const ConversionTarget
&target
;
1879 /// The pattern applicator to use for conversions.
1880 PatternApplicator applicator
;
1882 /// Dialect conversion configuration.
1883 const ConversionConfig
&config
;
1887 OperationLegalizer::OperationLegalizer(const ConversionTarget
&targetInfo
,
1888 const FrozenRewritePatternSet
&patterns
,
1889 const ConversionConfig
&config
)
1890 : target(targetInfo
), applicator(patterns
), config(config
) {
1891 // The set of patterns that can be applied to illegal operations to transform
1892 // them into legal ones.
1893 DenseMap
<OperationName
, LegalizationPatterns
> legalizerPatterns
;
1894 LegalizationPatterns anyOpLegalizerPatterns
;
1896 buildLegalizationGraph(anyOpLegalizerPatterns
, legalizerPatterns
);
1897 computeLegalizationGraphBenefit(anyOpLegalizerPatterns
, legalizerPatterns
);
1900 bool OperationLegalizer::isIllegal(Operation
*op
) const {
1901 return target
.isIllegal(op
);
1905 OperationLegalizer::legalize(Operation
*op
,
1906 ConversionPatternRewriter
&rewriter
) {
1908 const char *logLineComment
=
1909 "//===-------------------------------------------===//\n";
1911 auto &logger
= rewriter
.getImpl().logger
;
1914 logger
.getOStream() << "\n";
1915 logger
.startLine() << logLineComment
;
1916 logger
.startLine() << "Legalizing operation : '" << op
->getName() << "'("
1920 // If the operation has no regions, just print it here.
1921 if (op
->getNumRegions() == 0) {
1922 op
->print(logger
.startLine(), OpPrintingFlags().printGenericOpForm());
1923 logger
.getOStream() << "\n\n";
1927 // Check if this operation is legal on the target.
1928 if (auto legalityInfo
= target
.isLegal(op
)) {
1931 logger
, "operation marked legal by the target{0}",
1932 legalityInfo
->isRecursivelyLegal
1933 ? "; NOTE: operation is recursively legal; skipping internals"
1935 logger
.startLine() << logLineComment
;
1938 // If this operation is recursively legal, mark its children as ignored so
1939 // that we don't consider them for legalization.
1940 if (legalityInfo
->isRecursivelyLegal
) {
1941 op
->walk([&](Operation
*nested
) {
1943 rewriter
.getImpl().ignoredOps
.insert(nested
);
1950 // Check to see if the operation is ignored and doesn't need to be converted.
1951 if (rewriter
.getImpl().isOpIgnored(op
)) {
1953 logSuccess(logger
, "operation marked 'ignored' during conversion");
1954 logger
.startLine() << logLineComment
;
1959 // If the operation isn't legal, try to fold it in-place.
1960 // TODO: Should we always try to do this, even if the op is
1962 if (succeeded(legalizeWithFold(op
, rewriter
))) {
1964 logSuccess(logger
, "operation was folded");
1965 logger
.startLine() << logLineComment
;
1970 // Otherwise, we need to apply a legalization pattern to this operation.
1971 if (succeeded(legalizeWithPattern(op
, rewriter
))) {
1973 logSuccess(logger
, "");
1974 logger
.startLine() << logLineComment
;
1980 logFailure(logger
, "no matched legalization pattern");
1981 logger
.startLine() << logLineComment
;
1987 OperationLegalizer::legalizeWithFold(Operation
*op
,
1988 ConversionPatternRewriter
&rewriter
) {
1989 auto &rewriterImpl
= rewriter
.getImpl();
1990 RewriterState curState
= rewriterImpl
.getCurrentState();
1993 rewriterImpl
.logger
.startLine() << "* Fold {\n";
1994 rewriterImpl
.logger
.indent();
1997 // Try to fold the operation.
1998 SmallVector
<Value
, 2> replacementValues
;
1999 rewriter
.setInsertionPoint(op
);
2000 if (failed(rewriter
.tryFold(op
, replacementValues
))) {
2001 LLVM_DEBUG(logFailure(rewriterImpl
.logger
, "unable to fold"));
2004 // An empty list of replacement values indicates that the fold was in-place.
2005 // As the operation changed, a new legalization needs to be attempted.
2006 if (replacementValues
.empty())
2007 return legalize(op
, rewriter
);
2009 // Insert a replacement for 'op' with the folded replacement values.
2010 rewriter
.replaceOp(op
, replacementValues
);
2012 // Recursively legalize any new constant operations.
2013 for (unsigned i
= curState
.numRewrites
, e
= rewriterImpl
.rewrites
.size();
2016 dyn_cast
<CreateOperationRewrite
>(rewriterImpl
.rewrites
[i
].get());
2019 if (failed(legalize(createOp
->getOperation(), rewriter
))) {
2020 LLVM_DEBUG(logFailure(rewriterImpl
.logger
,
2021 "failed to legalize generated constant '{0}'",
2022 createOp
->getOperation()->getName()));
2023 rewriterImpl
.resetState(curState
);
2028 LLVM_DEBUG(logSuccess(rewriterImpl
.logger
, ""));
2033 OperationLegalizer::legalizeWithPattern(Operation
*op
,
2034 ConversionPatternRewriter
&rewriter
) {
2035 auto &rewriterImpl
= rewriter
.getImpl();
2037 // Functor that returns if the given pattern may be applied.
2038 auto canApply
= [&](const Pattern
&pattern
) {
2039 bool canApply
= canApplyPattern(op
, pattern
, rewriter
);
2040 if (canApply
&& config
.listener
)
2041 config
.listener
->notifyPatternBegin(pattern
, op
);
2045 // Functor that cleans up the rewriter state after a pattern failed to match.
2046 RewriterState curState
= rewriterImpl
.getCurrentState();
2047 auto onFailure
= [&](const Pattern
&pattern
) {
2048 assert(rewriterImpl
.pendingRootUpdates
.empty() && "dangling root updates");
2050 logFailure(rewriterImpl
.logger
, "pattern failed to match");
2051 if (rewriterImpl
.config
.notifyCallback
) {
2052 Diagnostic
diag(op
->getLoc(), DiagnosticSeverity::Remark
);
2053 diag
<< "Failed to apply pattern \"" << pattern
.getDebugName()
2056 rewriterImpl
.config
.notifyCallback(diag
);
2059 if (config
.listener
)
2060 config
.listener
->notifyPatternEnd(pattern
, failure());
2061 rewriterImpl
.resetState(curState
);
2062 appliedPatterns
.erase(&pattern
);
2065 // Functor that performs additional legalization when a pattern is
2066 // successfully applied.
2067 auto onSuccess
= [&](const Pattern
&pattern
) {
2068 assert(rewriterImpl
.pendingRootUpdates
.empty() && "dangling root updates");
2069 auto result
= legalizePatternResult(op
, pattern
, rewriter
, curState
);
2070 appliedPatterns
.erase(&pattern
);
2072 rewriterImpl
.resetState(curState
);
2073 if (config
.listener
)
2074 config
.listener
->notifyPatternEnd(pattern
, result
);
2078 // Try to match and rewrite a pattern on this operation.
2079 return applicator
.matchAndRewrite(op
, rewriter
, canApply
, onFailure
,
2083 bool OperationLegalizer::canApplyPattern(Operation
*op
, const Pattern
&pattern
,
2084 ConversionPatternRewriter
&rewriter
) {
2086 auto &os
= rewriter
.getImpl().logger
;
2087 os
.getOStream() << "\n";
2088 os
.startLine() << "* Pattern : '" << op
->getName() << " -> (";
2089 llvm::interleaveComma(pattern
.getGeneratedOps(), os
.getOStream());
2090 os
.getOStream() << ")' {\n";
2094 // Ensure that we don't cycle by not allowing the same pattern to be
2095 // applied twice in the same recursion stack if it is not known to be safe.
2096 if (!pattern
.hasBoundedRewriteRecursion() &&
2097 !appliedPatterns
.insert(&pattern
).second
) {
2099 logFailure(rewriter
.getImpl().logger
, "pattern was already applied"));
2106 OperationLegalizer::legalizePatternResult(Operation
*op
, const Pattern
&pattern
,
2107 ConversionPatternRewriter
&rewriter
,
2108 RewriterState
&curState
) {
2109 auto &impl
= rewriter
.getImpl();
2112 assert(impl
.pendingRootUpdates
.empty() && "dangling root updates");
2113 // Check that the root was either replaced or updated in place.
2114 auto newRewrites
= llvm::drop_begin(impl
.rewrites
, curState
.numRewrites
);
2115 auto replacedRoot
= [&] {
2116 return hasRewrite
<ReplaceOperationRewrite
>(newRewrites
, op
);
2118 auto updatedRootInPlace
= [&] {
2119 return hasRewrite
<ModifyOperationRewrite
>(newRewrites
, op
);
2121 assert((replacedRoot() || updatedRootInPlace()) &&
2122 "expected pattern to replace the root operation");
2125 // Legalize each of the actions registered during application.
2126 RewriterState newState
= impl
.getCurrentState();
2127 if (failed(legalizePatternBlockRewrites(op
, rewriter
, impl
, curState
,
2129 failed(legalizePatternRootUpdates(rewriter
, impl
, curState
, newState
)) ||
2130 failed(legalizePatternCreatedOperations(rewriter
, impl
, curState
,
2135 LLVM_DEBUG(logSuccess(impl
.logger
, "pattern applied successfully"));
2139 LogicalResult
OperationLegalizer::legalizePatternBlockRewrites(
2140 Operation
*op
, ConversionPatternRewriter
&rewriter
,
2141 ConversionPatternRewriterImpl
&impl
, RewriterState
&state
,
2142 RewriterState
&newState
) {
2143 SmallPtrSet
<Operation
*, 16> operationsToIgnore
;
2145 // If the pattern moved or created any blocks, make sure the types of block
2146 // arguments get legalized.
2147 for (int i
= state
.numRewrites
, e
= newState
.numRewrites
; i
!= e
; ++i
) {
2148 BlockRewrite
*rewrite
= dyn_cast
<BlockRewrite
>(impl
.rewrites
[i
].get());
2151 Block
*block
= rewrite
->getBlock();
2152 if (isa
<BlockTypeConversionRewrite
, EraseBlockRewrite
,
2153 ReplaceBlockArgRewrite
>(rewrite
))
2155 // Only check blocks outside of the current operation.
2156 Operation
*parentOp
= block
->getParentOp();
2157 if (!parentOp
|| parentOp
== op
|| block
->getNumArguments() == 0)
2160 // If the region of the block has a type converter, try to convert the block
2162 if (auto *converter
= impl
.regionToConverter
.lookup(block
->getParent())) {
2163 std::optional
<TypeConverter::SignatureConversion
> conversion
=
2164 converter
->convertBlockSignature(block
);
2166 LLVM_DEBUG(logFailure(impl
.logger
, "failed to convert types of moved "
2170 impl
.applySignatureConversion(rewriter
, block
, converter
, *conversion
);
2174 // Otherwise, check that this operation isn't one generated by this pattern.
2175 // This is because we will attempt to legalize the parent operation, and
2176 // blocks in regions created by this pattern will already be legalized later
2177 // on. If we haven't built the set yet, build it now.
2178 if (operationsToIgnore
.empty()) {
2179 for (unsigned i
= state
.numRewrites
, e
= impl
.rewrites
.size(); i
!= e
;
2182 dyn_cast
<CreateOperationRewrite
>(impl
.rewrites
[i
].get());
2185 operationsToIgnore
.insert(createOp
->getOperation());
2189 // If this operation should be considered for re-legalization, try it.
2190 if (operationsToIgnore
.insert(parentOp
).second
&&
2191 failed(legalize(parentOp
, rewriter
))) {
2192 LLVM_DEBUG(logFailure(impl
.logger
,
2193 "operation '{0}'({1}) became illegal after rewrite",
2194 parentOp
->getName(), parentOp
));
2201 LogicalResult
OperationLegalizer::legalizePatternCreatedOperations(
2202 ConversionPatternRewriter
&rewriter
, ConversionPatternRewriterImpl
&impl
,
2203 RewriterState
&state
, RewriterState
&newState
) {
2204 for (int i
= state
.numRewrites
, e
= newState
.numRewrites
; i
!= e
; ++i
) {
2205 auto *createOp
= dyn_cast
<CreateOperationRewrite
>(impl
.rewrites
[i
].get());
2208 Operation
*op
= createOp
->getOperation();
2209 if (failed(legalize(op
, rewriter
))) {
2210 LLVM_DEBUG(logFailure(impl
.logger
,
2211 "failed to legalize generated operation '{0}'({1})",
2212 op
->getName(), op
));
2219 LogicalResult
OperationLegalizer::legalizePatternRootUpdates(
2220 ConversionPatternRewriter
&rewriter
, ConversionPatternRewriterImpl
&impl
,
2221 RewriterState
&state
, RewriterState
&newState
) {
2222 for (int i
= state
.numRewrites
, e
= newState
.numRewrites
; i
!= e
; ++i
) {
2223 auto *rewrite
= dyn_cast
<ModifyOperationRewrite
>(impl
.rewrites
[i
].get());
2226 Operation
*op
= rewrite
->getOperation();
2227 if (failed(legalize(op
, rewriter
))) {
2228 LLVM_DEBUG(logFailure(
2229 impl
.logger
, "failed to legalize operation updated in-place '{0}'",
2237 //===----------------------------------------------------------------------===//
2240 void OperationLegalizer::buildLegalizationGraph(
2241 LegalizationPatterns
&anyOpLegalizerPatterns
,
2242 DenseMap
<OperationName
, LegalizationPatterns
> &legalizerPatterns
) {
2243 // A mapping between an operation and a set of operations that can be used to
2245 DenseMap
<OperationName
, SmallPtrSet
<OperationName
, 2>> parentOps
;
2246 // A mapping between an operation and any currently invalid patterns it has.
2247 DenseMap
<OperationName
, SmallPtrSet
<const Pattern
*, 2>> invalidPatterns
;
2248 // A worklist of patterns to consider for legality.
2249 SetVector
<const Pattern
*> patternWorklist
;
2251 // Build the mapping from operations to the parent ops that may generate them.
2252 applicator
.walkAllPatterns([&](const Pattern
&pattern
) {
2253 std::optional
<OperationName
> root
= pattern
.getRootKind();
2255 // If the pattern has no specific root, we can't analyze the relationship
2256 // between the root op and generated operations. Given that, add all such
2257 // patterns to the legalization set.
2259 anyOpLegalizerPatterns
.push_back(&pattern
);
2263 // Skip operations that are always known to be legal.
2264 if (target
.getOpAction(*root
) == LegalizationAction::Legal
)
2267 // Add this pattern to the invalid set for the root op and record this root
2268 // as a parent for any generated operations.
2269 invalidPatterns
[*root
].insert(&pattern
);
2270 for (auto op
: pattern
.getGeneratedOps())
2271 parentOps
[op
].insert(*root
);
2273 // Add this pattern to the worklist.
2274 patternWorklist
.insert(&pattern
);
2277 // If there are any patterns that don't have a specific root kind, we can't
2278 // make direct assumptions about what operations will never be legalized.
2279 // Note: Technically we could, but it would require an analysis that may
2280 // recurse into itself. It would be better to perform this kind of filtering
2281 // at a higher level than here anyways.
2282 if (!anyOpLegalizerPatterns
.empty()) {
2283 for (const Pattern
*pattern
: patternWorklist
)
2284 legalizerPatterns
[*pattern
->getRootKind()].push_back(pattern
);
2288 while (!patternWorklist
.empty()) {
2289 auto *pattern
= patternWorklist
.pop_back_val();
2291 // Check to see if any of the generated operations are invalid.
2292 if (llvm::any_of(pattern
->getGeneratedOps(), [&](OperationName op
) {
2293 std::optional
<LegalizationAction
> action
= target
.getOpAction(op
);
2294 return !legalizerPatterns
.count(op
) &&
2295 (!action
|| action
== LegalizationAction::Illegal
);
2299 // Otherwise, if all of the generated operation are valid, this op is now
2300 // legal so add all of the child patterns to the worklist.
2301 legalizerPatterns
[*pattern
->getRootKind()].push_back(pattern
);
2302 invalidPatterns
[*pattern
->getRootKind()].erase(pattern
);
2304 // Add any invalid patterns of the parent operations to see if they have now
2306 for (auto op
: parentOps
[*pattern
->getRootKind()])
2307 patternWorklist
.set_union(invalidPatterns
[op
]);
2311 void OperationLegalizer::computeLegalizationGraphBenefit(
2312 LegalizationPatterns
&anyOpLegalizerPatterns
,
2313 DenseMap
<OperationName
, LegalizationPatterns
> &legalizerPatterns
) {
2314 // The smallest pattern depth, when legalizing an operation.
2315 DenseMap
<OperationName
, unsigned> minOpPatternDepth
;
2317 // For each operation that is transitively legal, compute a cost for it.
2318 for (auto &opIt
: legalizerPatterns
)
2319 if (!minOpPatternDepth
.count(opIt
.first
))
2320 computeOpLegalizationDepth(opIt
.first
, minOpPatternDepth
,
2323 // Apply the cost model to the patterns that can match any operation. Those
2324 // with a specific operation type are already resolved when computing the op
2325 // legalization depth.
2326 if (!anyOpLegalizerPatterns
.empty())
2327 applyCostModelToPatterns(anyOpLegalizerPatterns
, minOpPatternDepth
,
2330 // Apply a cost model to the pattern applicator. We order patterns first by
2331 // depth then benefit. `legalizerPatterns` contains per-op patterns by
2332 // decreasing benefit.
2333 applicator
.applyCostModel([&](const Pattern
&pattern
) {
2334 ArrayRef
<const Pattern
*> orderedPatternList
;
2335 if (std::optional
<OperationName
> rootName
= pattern
.getRootKind())
2336 orderedPatternList
= legalizerPatterns
[*rootName
];
2338 orderedPatternList
= anyOpLegalizerPatterns
;
2340 // If the pattern is not found, then it was removed and cannot be matched.
2341 auto *it
= llvm::find(orderedPatternList
, &pattern
);
2342 if (it
== orderedPatternList
.end())
2343 return PatternBenefit::impossibleToMatch();
2345 // Patterns found earlier in the list have higher benefit.
2346 return PatternBenefit(std::distance(it
, orderedPatternList
.end()));
2350 unsigned OperationLegalizer::computeOpLegalizationDepth(
2351 OperationName op
, DenseMap
<OperationName
, unsigned> &minOpPatternDepth
,
2352 DenseMap
<OperationName
, LegalizationPatterns
> &legalizerPatterns
) {
2353 // Check for existing depth.
2354 auto depthIt
= minOpPatternDepth
.find(op
);
2355 if (depthIt
!= minOpPatternDepth
.end())
2356 return depthIt
->second
;
2358 // If a mapping for this operation does not exist, then this operation
2359 // is always legal. Return 0 as the depth for a directly legal operation.
2360 auto opPatternsIt
= legalizerPatterns
.find(op
);
2361 if (opPatternsIt
== legalizerPatterns
.end() || opPatternsIt
->second
.empty())
2364 // Record this initial depth in case we encounter this op again when
2365 // recursively computing the depth.
2366 minOpPatternDepth
.try_emplace(op
, std::numeric_limits
<unsigned>::max());
2368 // Apply the cost model to the operation patterns, and update the minimum
2370 unsigned minDepth
= applyCostModelToPatterns(
2371 opPatternsIt
->second
, minOpPatternDepth
, legalizerPatterns
);
2372 minOpPatternDepth
[op
] = minDepth
;
2376 unsigned OperationLegalizer::applyCostModelToPatterns(
2377 LegalizationPatterns
&patterns
,
2378 DenseMap
<OperationName
, unsigned> &minOpPatternDepth
,
2379 DenseMap
<OperationName
, LegalizationPatterns
> &legalizerPatterns
) {
2380 unsigned minDepth
= std::numeric_limits
<unsigned>::max();
2382 // Compute the depth for each pattern within the set.
2383 SmallVector
<std::pair
<const Pattern
*, unsigned>, 4> patternsByDepth
;
2384 patternsByDepth
.reserve(patterns
.size());
2385 for (const Pattern
*pattern
: patterns
) {
2387 for (auto generatedOp
: pattern
->getGeneratedOps()) {
2388 unsigned generatedOpDepth
= computeOpLegalizationDepth(
2389 generatedOp
, minOpPatternDepth
, legalizerPatterns
);
2390 depth
= std::max(depth
, generatedOpDepth
+ 1);
2392 patternsByDepth
.emplace_back(pattern
, depth
);
2394 // Update the minimum depth of the pattern list.
2395 minDepth
= std::min(minDepth
, depth
);
2398 // If the operation only has one legalization pattern, there is no need to
2400 if (patternsByDepth
.size() == 1)
2403 // Sort the patterns by those likely to be the most beneficial.
2404 std::stable_sort(patternsByDepth
.begin(), patternsByDepth
.end(),
2405 [](const std::pair
<const Pattern
*, unsigned> &lhs
,
2406 const std::pair
<const Pattern
*, unsigned> &rhs
) {
2407 // First sort by the smaller pattern legalization
2409 if (lhs
.second
!= rhs
.second
)
2410 return lhs
.second
< rhs
.second
;
2412 // Then sort by the larger pattern benefit.
2413 auto lhsBenefit
= lhs
.first
->getBenefit();
2414 auto rhsBenefit
= rhs
.first
->getBenefit();
2415 return lhsBenefit
> rhsBenefit
;
2418 // Update the legalization pattern to use the new sorted list.
2420 for (auto &patternIt
: patternsByDepth
)
2421 patterns
.push_back(patternIt
.first
);
//===----------------------------------------------------------------------===//
// OperationConverter
//===----------------------------------------------------------------------===//

/// Selects how OperationConverter treats operations that fail to legalize.
enum OpConversionMode {
  /// In this mode, the conversion will ignore failed conversions to allow
  /// illegal operations to co-exist in the IR.
  Partial,

  /// In this mode, all operations must be legal for the given target for the
  /// conversion to succeed.
  Full,

  /// In this mode, operations are analyzed for legality. No actual rewrites are
  /// applied to the operations on success.
  Analysis,
};
2445 // This class converts operations to a given conversion target via a set of
2446 // rewrite patterns. The conversion behaves differently depending on the
2448 struct OperationConverter
{
2449 explicit OperationConverter(const ConversionTarget
&target
,
2450 const FrozenRewritePatternSet
&patterns
,
2451 const ConversionConfig
&config
,
2452 OpConversionMode mode
)
2453 : config(config
), opLegalizer(target
, patterns
, this->config
),
2456 /// Converts the given operations to the conversion target.
2457 LogicalResult
convertOperations(ArrayRef
<Operation
*> ops
);
2460 /// Converts an operation with the given rewriter.
2461 LogicalResult
convert(ConversionPatternRewriter
&rewriter
, Operation
*op
);
2463 /// Dialect conversion configuration.
2464 ConversionConfig config
;
2466 /// The legalizer to use when converting operations.
2467 OperationLegalizer opLegalizer
;
2469 /// The conversion mode to use when legalizing operations.
2470 OpConversionMode mode
;
2474 LogicalResult
OperationConverter::convert(ConversionPatternRewriter
&rewriter
,
2476 // Legalize the given operation.
2477 if (failed(opLegalizer
.legalize(op
, rewriter
))) {
2478 // Handle the case of a failed conversion for each of the different modes.
2479 // Full conversions expect all operations to be converted.
2480 if (mode
== OpConversionMode::Full
)
2481 return op
->emitError()
2482 << "failed to legalize operation '" << op
->getName() << "'";
2483 // Partial conversions allow conversions to fail iff the operation was not
2484 // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2485 // set, non-legalizable ops are added to that set.
2486 if (mode
== OpConversionMode::Partial
) {
2487 if (opLegalizer
.isIllegal(op
))
2488 return op
->emitError()
2489 << "failed to legalize operation '" << op
->getName()
2490 << "' that was explicitly marked illegal";
2491 if (config
.unlegalizedOps
)
2492 config
.unlegalizedOps
->insert(op
);
2494 } else if (mode
== OpConversionMode::Analysis
) {
2495 // Analysis conversions don't fail if any operations fail to legalize,
2496 // they are only interested in the operations that were successfully
2498 if (config
.legalizableOps
)
2499 config
.legalizableOps
->insert(op
);
2504 static LogicalResult
2505 legalizeUnresolvedMaterialization(RewriterBase
&rewriter
,
2506 UnresolvedMaterializationRewrite
*rewrite
) {
2507 UnrealizedConversionCastOp op
= rewrite
->getOperation();
2508 assert(!op
.use_empty() &&
2509 "expected that dead materializations have already been DCE'd");
2510 Operation::operand_range inputOperands
= op
.getOperands();
2511 Type outputType
= op
.getResultTypes()[0];
2513 // Try to materialize the conversion.
2514 if (const TypeConverter
*converter
= rewrite
->getConverter()) {
2515 rewriter
.setInsertionPoint(op
);
2516 Value newMaterialization
;
2517 switch (rewrite
->getMaterializationKind()) {
2518 case MaterializationKind::Argument
:
2519 // Try to materialize an argument conversion.
2520 newMaterialization
= converter
->materializeArgumentConversion(
2521 rewriter
, op
->getLoc(), outputType
, inputOperands
);
2522 if (newMaterialization
)
2524 // If an argument materialization failed, fallback to trying a target
2527 case MaterializationKind::Target
:
2528 newMaterialization
= converter
->materializeTargetConversion(
2529 rewriter
, op
->getLoc(), outputType
, inputOperands
,
2530 rewrite
->getOriginalType());
2532 case MaterializationKind::Source
:
2533 newMaterialization
= converter
->materializeSourceConversion(
2534 rewriter
, op
->getLoc(), outputType
, inputOperands
);
2537 if (newMaterialization
) {
2538 assert(newMaterialization
.getType() == outputType
&&
2539 "materialization callback produced value of incorrect type");
2540 rewriter
.replaceOp(op
, newMaterialization
);
2545 InFlightDiagnostic diag
=
2546 op
->emitError() << "failed to legalize unresolved materialization "
2548 << inputOperands
.getTypes() << ") to (" << outputType
2549 << ") that remained live after conversion";
2550 diag
.attachNote(op
->getUsers().begin()->getLoc())
2551 << "see existing live user here: " << *op
->getUsers().begin();
2555 LogicalResult
OperationConverter::convertOperations(ArrayRef
<Operation
*> ops
) {
2558 const ConversionTarget
&target
= opLegalizer
.getTarget();
2560 // Compute the set of operations and blocks to convert.
2561 SmallVector
<Operation
*> toConvert
;
2562 for (auto *op
: ops
) {
2563 op
->walk
<WalkOrder::PreOrder
, ForwardDominanceIterator
<>>(
2564 [&](Operation
*op
) {
2565 toConvert
.push_back(op
);
2566 // Don't check this operation's children for conversion if the
2567 // operation is recursively legal.
2568 auto legalityInfo
= target
.isLegal(op
);
2569 if (legalityInfo
&& legalityInfo
->isRecursivelyLegal
)
2570 return WalkResult::skip();
2571 return WalkResult::advance();
2575 // Convert each operation and discard rewrites on failure.
2576 ConversionPatternRewriter
rewriter(ops
.front()->getContext(), config
);
2577 ConversionPatternRewriterImpl
&rewriterImpl
= rewriter
.getImpl();
2579 for (auto *op
: toConvert
)
2580 if (failed(convert(rewriter
, op
)))
2581 return rewriterImpl
.undoRewrites(), failure();
2583 // After a successful conversion, apply rewrites.
2584 rewriterImpl
.applyRewrites();
2586 // Gather all unresolved materializations.
2587 SmallVector
<UnrealizedConversionCastOp
> allCastOps
;
2588 const DenseMap
<UnrealizedConversionCastOp
, UnresolvedMaterializationRewrite
*>
2589 &materializations
= rewriterImpl
.unresolvedMaterializations
;
2590 for (auto it
: materializations
) {
2591 if (rewriterImpl
.eraseRewriter
.wasErased(it
.first
))
2593 allCastOps
.push_back(it
.first
);
2596 // Reconcile all UnrealizedConversionCastOps that were inserted by the
2597 // dialect conversion frameworks. (Not the one that were inserted by
2599 SmallVector
<UnrealizedConversionCastOp
> remainingCastOps
;
2600 reconcileUnrealizedCasts(allCastOps
, &remainingCastOps
);
2602 // Try to legalize all unresolved materializations.
2603 if (config
.buildMaterializations
) {
2604 IRRewriter
rewriter(rewriterImpl
.context
, config
.listener
);
2605 for (UnrealizedConversionCastOp castOp
: remainingCastOps
) {
2606 auto it
= materializations
.find(castOp
);
2607 assert(it
!= materializations
.end() && "inconsistent state");
2608 if (failed(legalizeUnresolvedMaterialization(rewriter
, it
->second
)))
//===----------------------------------------------------------------------===//
// Reconcile Unrealized Casts
//===----------------------------------------------------------------------===//
2620 void mlir::reconcileUnrealizedCasts(
2621 ArrayRef
<UnrealizedConversionCastOp
> castOps
,
2622 SmallVectorImpl
<UnrealizedConversionCastOp
> *remainingCastOps
) {
2623 SetVector
<UnrealizedConversionCastOp
> worklist(castOps
.begin(),
2625 // This set is maintained only if `remainingCastOps` is provided.
2626 DenseSet
<Operation
*> erasedOps
;
2628 // Helper function that adds all operands to the worklist that are an
2629 // unrealized_conversion_cast op result.
2630 auto enqueueOperands
= [&](UnrealizedConversionCastOp castOp
) {
2631 for (Value v
: castOp
.getInputs())
2632 if (auto inputCastOp
= v
.getDefiningOp
<UnrealizedConversionCastOp
>())
2633 worklist
.insert(inputCastOp
);
2636 // Helper function that return the unrealized_conversion_cast op that
2637 // defines all inputs of the given op (in the same order). Return "nullptr"
2638 // if there is no such op.
2640 [](UnrealizedConversionCastOp castOp
) -> UnrealizedConversionCastOp
{
2641 if (castOp
.getInputs().empty())
2644 castOp
.getInputs().front().getDefiningOp
<UnrealizedConversionCastOp
>();
2647 if (inputCastOp
.getOutputs() != castOp
.getInputs())
2652 // Process ops in the worklist bottom-to-top.
2653 while (!worklist
.empty()) {
2654 UnrealizedConversionCastOp castOp
= worklist
.pop_back_val();
2655 if (castOp
->use_empty()) {
2656 // DCE: If the op has no users, erase it. Add the operands to the
2657 // worklist to find additional DCE opportunities.
2658 enqueueOperands(castOp
);
2659 if (remainingCastOps
)
2660 erasedOps
.insert(castOp
.getOperation());
2665 // Traverse the chain of input cast ops to see if an op with the same
2666 // input types can be found.
2667 UnrealizedConversionCastOp nextCast
= castOp
;
2669 if (nextCast
.getInputs().getTypes() == castOp
.getResultTypes()) {
2670 // Found a cast where the input types match the output types of the
2671 // matched op. We can directly use those inputs and the matched op can
2673 enqueueOperands(castOp
);
2674 castOp
.replaceAllUsesWith(nextCast
.getInputs());
2675 if (remainingCastOps
)
2676 erasedOps
.insert(castOp
.getOperation());
2680 nextCast
= getInputCast(nextCast
);
2684 if (remainingCastOps
)
2685 for (UnrealizedConversionCastOp op
: castOps
)
2686 if (!erasedOps
.contains(op
.getOperation()))
2687 remainingCastOps
->push_back(op
);
2690 //===----------------------------------------------------------------------===//
2692 //===----------------------------------------------------------------------===//
2694 void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo
,
2695 ArrayRef
<Type
> types
) {
2696 assert(!types
.empty() && "expected valid types");
2697 remapInput(origInputNo
, /*newInputNo=*/argTypes
.size(), types
.size());
2701 void TypeConverter::SignatureConversion::addInputs(ArrayRef
<Type
> types
) {
2702 assert(!types
.empty() &&
2703 "1->0 type remappings don't need to be added explicitly");
2704 argTypes
.append(types
.begin(), types
.end());
2707 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo
,
2708 unsigned newInputNo
,
2709 unsigned newInputCount
) {
2710 assert(!remappedInputs
[origInputNo
] && "input has already been remapped");
2711 assert(newInputCount
!= 0 && "expected valid input count");
2712 remappedInputs
[origInputNo
] =
2713 InputMapping
{newInputNo
, newInputCount
, /*replacementValue=*/nullptr};
2716 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo
,
2717 Value replacementValue
) {
2718 assert(!remappedInputs
[origInputNo
] && "input has already been remapped");
2719 remappedInputs
[origInputNo
] =
2720 InputMapping
{origInputNo
, /*size=*/0, replacementValue
};
2723 LogicalResult
TypeConverter::convertType(Type t
,
2724 SmallVectorImpl
<Type
> &results
) const {
2726 std::shared_lock
<decltype(cacheMutex
)> cacheReadLock(cacheMutex
,
2728 if (t
.getContext()->isMultithreadingEnabled())
2729 cacheReadLock
.lock();
2730 auto existingIt
= cachedDirectConversions
.find(t
);
2731 if (existingIt
!= cachedDirectConversions
.end()) {
2732 if (existingIt
->second
)
2733 results
.push_back(existingIt
->second
);
2734 return success(existingIt
->second
!= nullptr);
2736 auto multiIt
= cachedMultiConversions
.find(t
);
2737 if (multiIt
!= cachedMultiConversions
.end()) {
2738 results
.append(multiIt
->second
.begin(), multiIt
->second
.end());
2742 // Walk the added converters in reverse order to apply the most recently
2743 // registered first.
2744 size_t currentCount
= results
.size();
2746 std::unique_lock
<decltype(cacheMutex
)> cacheWriteLock(cacheMutex
,
2749 for (const ConversionCallbackFn
&converter
: llvm::reverse(conversions
)) {
2750 if (std::optional
<LogicalResult
> result
= converter(t
, results
)) {
2751 if (t
.getContext()->isMultithreadingEnabled())
2752 cacheWriteLock
.lock();
2753 if (!succeeded(*result
)) {
2754 cachedDirectConversions
.try_emplace(t
, nullptr);
2757 auto newTypes
= ArrayRef
<Type
>(results
).drop_front(currentCount
);
2758 if (newTypes
.size() == 1)
2759 cachedDirectConversions
.try_emplace(t
, newTypes
.front());
2761 cachedMultiConversions
.try_emplace(t
, llvm::to_vector
<2>(newTypes
));
2768 Type
TypeConverter::convertType(Type t
) const {
2769 // Use the multi-type result version to convert the type.
2770 SmallVector
<Type
, 1> results
;
2771 if (failed(convertType(t
, results
)))
2774 // Check to ensure that only one type was produced.
2775 return results
.size() == 1 ? results
.front() : nullptr;
2779 TypeConverter::convertTypes(TypeRange types
,
2780 SmallVectorImpl
<Type
> &results
) const {
2781 for (Type type
: types
)
2782 if (failed(convertType(type
, results
)))
2787 bool TypeConverter::isLegal(Type type
) const {
2788 return convertType(type
) == type
;
2790 bool TypeConverter::isLegal(Operation
*op
) const {
2791 return isLegal(op
->getOperandTypes()) && isLegal(op
->getResultTypes());
2794 bool TypeConverter::isLegal(Region
*region
) const {
2795 return llvm::all_of(*region
, [this](Block
&block
) {
2796 return isLegal(block
.getArgumentTypes());
2800 bool TypeConverter::isSignatureLegal(FunctionType ty
) const {
2801 return isLegal(llvm::concat
<const Type
>(ty
.getInputs(), ty
.getResults()));
2805 TypeConverter::convertSignatureArg(unsigned inputNo
, Type type
,
2806 SignatureConversion
&result
) const {
2807 // Try to convert the given input type.
2808 SmallVector
<Type
, 1> convertedTypes
;
2809 if (failed(convertType(type
, convertedTypes
)))
2812 // If this argument is being dropped, there is nothing left to do.
2813 if (convertedTypes
.empty())
2816 // Otherwise, add the new inputs.
2817 result
.addInputs(inputNo
, convertedTypes
);
2821 TypeConverter::convertSignatureArgs(TypeRange types
,
2822 SignatureConversion
&result
,
2823 unsigned origInputOffset
) const {
2824 for (unsigned i
= 0, e
= types
.size(); i
!= e
; ++i
)
2825 if (failed(convertSignatureArg(origInputOffset
+ i
, types
[i
], result
)))
2830 Value
TypeConverter::materializeArgumentConversion(OpBuilder
&builder
,
2833 ValueRange inputs
) const {
2834 for (const MaterializationCallbackFn
&fn
:
2835 llvm::reverse(argumentMaterializations
))
2836 if (Value result
= fn(builder
, resultType
, inputs
, loc
))
2841 Value
TypeConverter::materializeSourceConversion(OpBuilder
&builder
,
2842 Location loc
, Type resultType
,
2843 ValueRange inputs
) const {
2844 for (const MaterializationCallbackFn
&fn
:
2845 llvm::reverse(sourceMaterializations
))
2846 if (Value result
= fn(builder
, resultType
, inputs
, loc
))
2851 Value
TypeConverter::materializeTargetConversion(OpBuilder
&builder
,
2852 Location loc
, Type resultType
,
2854 Type originalType
) const {
2855 SmallVector
<Value
> result
= materializeTargetConversion(
2856 builder
, loc
, TypeRange(resultType
), inputs
, originalType
);
2859 assert(result
.size() == 1 && "expected single result");
2860 return result
.front();
2863 SmallVector
<Value
> TypeConverter::materializeTargetConversion(
2864 OpBuilder
&builder
, Location loc
, TypeRange resultTypes
, ValueRange inputs
,
2865 Type originalType
) const {
2866 for (const TargetMaterializationCallbackFn
&fn
:
2867 llvm::reverse(targetMaterializations
)) {
2868 SmallVector
<Value
> result
=
2869 fn(builder
, resultTypes
, inputs
, loc
, originalType
);
2872 assert(TypeRange(ValueRange(result
)) == resultTypes
&&
2873 "callback produced incorrect number of values or values with "
2880 std::optional
<TypeConverter::SignatureConversion
>
2881 TypeConverter::convertBlockSignature(Block
*block
) const {
2882 SignatureConversion
conversion(block
->getNumArguments());
2883 if (failed(convertSignatureArgs(block
->getArgumentTypes(), conversion
)))
2884 return std::nullopt
;
2888 //===----------------------------------------------------------------------===//
2889 // Type attribute conversion
2890 //===----------------------------------------------------------------------===//
2891 TypeConverter::AttributeConversionResult
2892 TypeConverter::AttributeConversionResult::result(Attribute attr
) {
2893 return AttributeConversionResult(attr
, resultTag
);
2896 TypeConverter::AttributeConversionResult
2897 TypeConverter::AttributeConversionResult::na() {
2898 return AttributeConversionResult(nullptr, naTag
);
2901 TypeConverter::AttributeConversionResult
2902 TypeConverter::AttributeConversionResult::abort() {
2903 return AttributeConversionResult(nullptr, abortTag
);
2906 bool TypeConverter::AttributeConversionResult::hasResult() const {
2907 return impl
.getInt() == resultTag
;
2910 bool TypeConverter::AttributeConversionResult::isNa() const {
2911 return impl
.getInt() == naTag
;
2914 bool TypeConverter::AttributeConversionResult::isAbort() const {
2915 return impl
.getInt() == abortTag
;
2918 Attribute
TypeConverter::AttributeConversionResult::getResult() const {
2919 assert(hasResult() && "Cannot get result from N/A or abort");
2920 return impl
.getPointer();
2923 std::optional
<Attribute
>
2924 TypeConverter::convertTypeAttribute(Type type
, Attribute attr
) const {
2925 for (const TypeAttributeConversionCallbackFn
&fn
:
2926 llvm::reverse(typeAttributeConversions
)) {
2927 AttributeConversionResult res
= fn(type
, attr
);
2928 if (res
.hasResult())
2929 return res
.getResult();
2931 return std::nullopt
;
2933 return std::nullopt
;
2936 //===----------------------------------------------------------------------===//
2937 // FunctionOpInterfaceSignatureConversion
2938 //===----------------------------------------------------------------------===//
2940 static LogicalResult
convertFuncOpTypes(FunctionOpInterface funcOp
,
2941 const TypeConverter
&typeConverter
,
2942 ConversionPatternRewriter
&rewriter
) {
2943 FunctionType type
= dyn_cast
<FunctionType
>(funcOp
.getFunctionType());
2947 // Convert the original function types.
2948 TypeConverter::SignatureConversion
result(type
.getNumInputs());
2949 SmallVector
<Type
, 1> newResults
;
2950 if (failed(typeConverter
.convertSignatureArgs(type
.getInputs(), result
)) ||
2951 failed(typeConverter
.convertTypes(type
.getResults(), newResults
)) ||
2952 failed(rewriter
.convertRegionTypes(&funcOp
.getFunctionBody(),
2953 typeConverter
, &result
)))
2956 // Update the function signature in-place.
2957 auto newType
= FunctionType::get(rewriter
.getContext(),
2958 result
.getConvertedTypes(), newResults
);
2960 rewriter
.modifyOpInPlace(funcOp
, [&] { funcOp
.setType(newType
); });
2965 /// Create a default conversion pattern that rewrites the type signature of a
2966 /// FunctionOpInterface op. This only supports ops which use FunctionType to
2967 /// represent their type.
2969 struct FunctionOpInterfaceSignatureConversion
: public ConversionPattern
{
2970 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName
,
2972 const TypeConverter
&converter
)
2973 : ConversionPattern(converter
, functionLikeOpName
, /*benefit=*/1, ctx
) {}
2976 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> /*operands*/,
2977 ConversionPatternRewriter
&rewriter
) const override
{
2978 FunctionOpInterface funcOp
= cast
<FunctionOpInterface
>(op
);
2979 return convertFuncOpTypes(funcOp
, *typeConverter
, rewriter
);
2983 struct AnyFunctionOpInterfaceSignatureConversion
2984 : public OpInterfaceConversionPattern
<FunctionOpInterface
> {
2985 using OpInterfaceConversionPattern::OpInterfaceConversionPattern
;
2988 matchAndRewrite(FunctionOpInterface funcOp
, ArrayRef
<Value
> /*operands*/,
2989 ConversionPatternRewriter
&rewriter
) const override
{
2990 return convertFuncOpTypes(funcOp
, *typeConverter
, rewriter
);
2995 FailureOr
<Operation
*>
2996 mlir::convertOpResultTypes(Operation
*op
, ValueRange operands
,
2997 const TypeConverter
&converter
,
2998 ConversionPatternRewriter
&rewriter
) {
2999 assert(op
&& "Invalid op");
3000 Location loc
= op
->getLoc();
3001 if (converter
.isLegal(op
))
3002 return rewriter
.notifyMatchFailure(loc
, "op already legal");
3004 OperationState
newOp(loc
, op
->getName());
3005 newOp
.addOperands(operands
);
3007 SmallVector
<Type
> newResultTypes
;
3008 if (failed(converter
.convertTypes(op
->getResultTypes(), newResultTypes
)))
3009 return rewriter
.notifyMatchFailure(loc
, "couldn't convert return types");
3011 newOp
.addTypes(newResultTypes
);
3012 newOp
.addAttributes(op
->getAttrs());
3013 return rewriter
.create(newOp
);
3016 void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3017 StringRef functionLikeOpName
, RewritePatternSet
&patterns
,
3018 const TypeConverter
&converter
) {
3019 patterns
.add
<FunctionOpInterfaceSignatureConversion
>(
3020 functionLikeOpName
, patterns
.getContext(), converter
);
3023 void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3024 RewritePatternSet
&patterns
, const TypeConverter
&converter
) {
3025 patterns
.add
<AnyFunctionOpInterfaceSignatureConversion
>(
3026 converter
, patterns
.getContext());
3029 //===----------------------------------------------------------------------===//
3031 //===----------------------------------------------------------------------===//
3033 void ConversionTarget::setOpAction(OperationName op
,
3034 LegalizationAction action
) {
3035 legalOperations
[op
].action
= action
;
3038 void ConversionTarget::setDialectAction(ArrayRef
<StringRef
> dialectNames
,
3039 LegalizationAction action
) {
3040 for (StringRef dialect
: dialectNames
)
3041 legalDialects
[dialect
] = action
;
3044 auto ConversionTarget::getOpAction(OperationName op
) const
3045 -> std::optional
<LegalizationAction
> {
3046 std::optional
<LegalizationInfo
> info
= getOpInfo(op
);
3047 return info
? info
->action
: std::optional
<LegalizationAction
>();
3050 auto ConversionTarget::isLegal(Operation
*op
) const
3051 -> std::optional
<LegalOpDetails
> {
3052 std::optional
<LegalizationInfo
> info
= getOpInfo(op
->getName());
3054 return std::nullopt
;
3056 // Returns true if this operation instance is known to be legal.
3057 auto isOpLegal
= [&] {
3058 // Handle dynamic legality either with the provided legality function.
3059 if (info
->action
== LegalizationAction::Dynamic
) {
3060 std::optional
<bool> result
= info
->legalityFn(op
);
3065 // Otherwise, the operation is only legal if it was marked 'Legal'.
3066 return info
->action
== LegalizationAction::Legal
;
3069 return std::nullopt
;
3071 // This operation is legal, compute any additional legality information.
3072 LegalOpDetails legalityDetails
;
3073 if (info
->isRecursivelyLegal
) {
3074 auto legalityFnIt
= opRecursiveLegalityFns
.find(op
->getName());
3075 if (legalityFnIt
!= opRecursiveLegalityFns
.end()) {
3076 legalityDetails
.isRecursivelyLegal
=
3077 legalityFnIt
->second(op
).value_or(true);
3079 legalityDetails
.isRecursivelyLegal
= true;
3082 return legalityDetails
;
3085 bool ConversionTarget::isIllegal(Operation
*op
) const {
3086 std::optional
<LegalizationInfo
> info
= getOpInfo(op
->getName());
3090 if (info
->action
== LegalizationAction::Dynamic
) {
3091 std::optional
<bool> result
= info
->legalityFn(op
);
3098 return info
->action
== LegalizationAction::Illegal
;
3101 static ConversionTarget::DynamicLegalityCallbackFn
composeLegalityCallbacks(
3102 ConversionTarget::DynamicLegalityCallbackFn oldCallback
,
3103 ConversionTarget::DynamicLegalityCallbackFn newCallback
) {
3107 auto chain
= [oldCl
= std::move(oldCallback
), newCl
= std::move(newCallback
)](
3108 Operation
*op
) -> std::optional
<bool> {
3109 if (std::optional
<bool> result
= newCl(op
))
3117 void ConversionTarget::setLegalityCallback(
3118 OperationName name
, const DynamicLegalityCallbackFn
&callback
) {
3119 assert(callback
&& "expected valid legality callback");
3120 auto *infoIt
= legalOperations
.find(name
);
3121 assert(infoIt
!= legalOperations
.end() &&
3122 infoIt
->second
.action
== LegalizationAction::Dynamic
&&
3123 "expected operation to already be marked as dynamically legal");
3124 infoIt
->second
.legalityFn
=
3125 composeLegalityCallbacks(std::move(infoIt
->second
.legalityFn
), callback
);
3128 void ConversionTarget::markOpRecursivelyLegal(
3129 OperationName name
, const DynamicLegalityCallbackFn
&callback
) {
3130 auto *infoIt
= legalOperations
.find(name
);
3131 assert(infoIt
!= legalOperations
.end() &&
3132 infoIt
->second
.action
!= LegalizationAction::Illegal
&&
3133 "expected operation to already be marked as legal");
3134 infoIt
->second
.isRecursivelyLegal
= true;
3136 opRecursiveLegalityFns
[name
] = composeLegalityCallbacks(
3137 std::move(opRecursiveLegalityFns
[name
]), callback
);
3139 opRecursiveLegalityFns
.erase(name
);
3142 void ConversionTarget::setLegalityCallback(
3143 ArrayRef
<StringRef
> dialects
, const DynamicLegalityCallbackFn
&callback
) {
3144 assert(callback
&& "expected valid legality callback");
3145 for (StringRef dialect
: dialects
)
3146 dialectLegalityFns
[dialect
] = composeLegalityCallbacks(
3147 std::move(dialectLegalityFns
[dialect
]), callback
);
3150 void ConversionTarget::setLegalityCallback(
3151 const DynamicLegalityCallbackFn
&callback
) {
3152 assert(callback
&& "expected valid legality callback");
3153 unknownLegalityFn
= composeLegalityCallbacks(unknownLegalityFn
, callback
);
3156 auto ConversionTarget::getOpInfo(OperationName op
) const
3157 -> std::optional
<LegalizationInfo
> {
3158 // Check for info for this specific operation.
3159 const auto *it
= legalOperations
.find(op
);
3160 if (it
!= legalOperations
.end())
3162 // Check for info for the parent dialect.
3163 auto dialectIt
= legalDialects
.find(op
.getDialectNamespace());
3164 if (dialectIt
!= legalDialects
.end()) {
3165 DynamicLegalityCallbackFn callback
;
3166 auto dialectFn
= dialectLegalityFns
.find(op
.getDialectNamespace());
3167 if (dialectFn
!= dialectLegalityFns
.end())
3168 callback
= dialectFn
->second
;
3169 return LegalizationInfo
{dialectIt
->second
, /*isRecursivelyLegal=*/false,
3172 // Otherwise, check if we mark unknown operations as dynamic.
3173 if (unknownLegalityFn
)
3174 return LegalizationInfo
{LegalizationAction::Dynamic
,
3175 /*isRecursivelyLegal=*/false, unknownLegalityFn
};
3176 return std::nullopt
;
3179 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3180 //===----------------------------------------------------------------------===//
3181 // PDL Configuration
3182 //===----------------------------------------------------------------------===//
3184 void PDLConversionConfig::notifyRewriteBegin(PatternRewriter
&rewriter
) {
3185 auto &rewriterImpl
=
3186 static_cast<ConversionPatternRewriter
&>(rewriter
).getImpl();
3187 rewriterImpl
.currentTypeConverter
= getTypeConverter();
3190 void PDLConversionConfig::notifyRewriteEnd(PatternRewriter
&rewriter
) {
3191 auto &rewriterImpl
=
3192 static_cast<ConversionPatternRewriter
&>(rewriter
).getImpl();
3193 rewriterImpl
.currentTypeConverter
= nullptr;
3196 /// Remap the given value using the rewriter and the type converter in the
3197 /// provided config.
3198 static FailureOr
<SmallVector
<Value
>>
3199 pdllConvertValues(ConversionPatternRewriter
&rewriter
, ValueRange values
) {
3200 SmallVector
<Value
> mappedValues
;
3201 if (failed(rewriter
.getRemappedValues(values
, mappedValues
)))
3203 return std::move(mappedValues
);
3206 void mlir::registerConversionPDLFunctions(RewritePatternSet
&patterns
) {
3207 patterns
.getPDLPatterns().registerRewriteFunction(
3209 [](PatternRewriter
&rewriter
, Value value
) -> FailureOr
<Value
> {
3210 auto results
= pdllConvertValues(
3211 static_cast<ConversionPatternRewriter
&>(rewriter
), value
);
3212 if (failed(results
))
3214 return results
->front();
3216 patterns
.getPDLPatterns().registerRewriteFunction(
3217 "convertValues", [](PatternRewriter
&rewriter
, ValueRange values
) {
3218 return pdllConvertValues(
3219 static_cast<ConversionPatternRewriter
&>(rewriter
), values
);
3221 patterns
.getPDLPatterns().registerRewriteFunction(
3223 [](PatternRewriter
&rewriter
, Type type
) -> FailureOr
<Type
> {
3224 auto &rewriterImpl
=
3225 static_cast<ConversionPatternRewriter
&>(rewriter
).getImpl();
3226 if (const TypeConverter
*converter
=
3227 rewriterImpl
.currentTypeConverter
) {
3228 if (Type newType
= converter
->convertType(type
))
3234 patterns
.getPDLPatterns().registerRewriteFunction(
3236 [](PatternRewriter
&rewriter
,
3237 TypeRange types
) -> FailureOr
<SmallVector
<Type
>> {
3238 auto &rewriterImpl
=
3239 static_cast<ConversionPatternRewriter
&>(rewriter
).getImpl();
3240 const TypeConverter
*converter
= rewriterImpl
.currentTypeConverter
;
3242 return SmallVector
<Type
>(types
);
3244 SmallVector
<Type
> remappedTypes
;
3245 if (failed(converter
->convertTypes(types
, remappedTypes
)))
3247 return std::move(remappedTypes
);
3250 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
3252 //===----------------------------------------------------------------------===//
3253 // Op Conversion Entry Points
3254 //===----------------------------------------------------------------------===//
3256 //===----------------------------------------------------------------------===//
3257 // Partial Conversion
3259 LogicalResult
mlir::applyPartialConversion(
3260 ArrayRef
<Operation
*> ops
, const ConversionTarget
&target
,
3261 const FrozenRewritePatternSet
&patterns
, ConversionConfig config
) {
3262 OperationConverter
opConverter(target
, patterns
, config
,
3263 OpConversionMode::Partial
);
3264 return opConverter
.convertOperations(ops
);
3267 mlir::applyPartialConversion(Operation
*op
, const ConversionTarget
&target
,
3268 const FrozenRewritePatternSet
&patterns
,
3269 ConversionConfig config
) {
3270 return applyPartialConversion(llvm::ArrayRef(op
), target
, patterns
, config
);
3273 //===----------------------------------------------------------------------===//
3276 LogicalResult
mlir::applyFullConversion(ArrayRef
<Operation
*> ops
,
3277 const ConversionTarget
&target
,
3278 const FrozenRewritePatternSet
&patterns
,
3279 ConversionConfig config
) {
3280 OperationConverter
opConverter(target
, patterns
, config
,
3281 OpConversionMode::Full
);
3282 return opConverter
.convertOperations(ops
);
3284 LogicalResult
mlir::applyFullConversion(Operation
*op
,
3285 const ConversionTarget
&target
,
3286 const FrozenRewritePatternSet
&patterns
,
3287 ConversionConfig config
) {
3288 return applyFullConversion(llvm::ArrayRef(op
), target
, patterns
, config
);
3291 //===----------------------------------------------------------------------===//
3292 // Analysis Conversion
3294 /// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
3295 /// op is a top-level module op (which is expected to be isolated from above),
3297 static Operation
*findCommonAncestor(ArrayRef
<Operation
*> ops
) {
3298 // Check if there is a top-level operation within `ops`. If so, return that
3300 for (Operation
*op
: ops
) {
3301 if (!op
->getParentOp()) {
3303 assert(op
->hasTrait
<OpTrait::IsIsolatedFromAbove
>() &&
3304 "expected top-level op to be isolated from above");
3305 for (Operation
*other
: ops
)
3306 assert(op
->isAncestor(other
) &&
3307 "expected ops to have a common ancestor");
3313 // No top-level op. Find a common ancestor.
3314 Operation
*commonAncestor
=
3315 ops
.front()->getParentWithTrait
<OpTrait::IsIsolatedFromAbove
>();
3316 for (Operation
*op
: ops
.drop_front()) {
3317 while (!commonAncestor
->isProperAncestor(op
)) {
3319 commonAncestor
->getParentWithTrait
<OpTrait::IsIsolatedFromAbove
>();
3320 assert(commonAncestor
&&
3321 "expected to find a common isolated from above ancestor");
3325 return commonAncestor
;
3328 LogicalResult
mlir::applyAnalysisConversion(
3329 ArrayRef
<Operation
*> ops
, ConversionTarget
&target
,
3330 const FrozenRewritePatternSet
&patterns
, ConversionConfig config
) {
3332 if (config
.legalizableOps
)
3333 assert(config
.legalizableOps
->empty() && "expected empty set");
3336 // Clone closted common ancestor that is isolated from above.
3337 Operation
*commonAncestor
= findCommonAncestor(ops
);
3339 Operation
*clonedAncestor
= commonAncestor
->clone(mapping
);
3340 // Compute inverse IR mapping.
3341 DenseMap
<Operation
*, Operation
*> inverseOperationMap
;
3342 for (auto &it
: mapping
.getOperationMap())
3343 inverseOperationMap
[it
.second
] = it
.first
;
3345 // Convert the cloned operations. The original IR will remain unchanged.
3346 SmallVector
<Operation
*> opsToConvert
= llvm::map_to_vector(
3347 ops
, [&](Operation
*op
) { return mapping
.lookup(op
); });
3348 OperationConverter
opConverter(target
, patterns
, config
,
3349 OpConversionMode::Analysis
);
3350 LogicalResult status
= opConverter
.convertOperations(opsToConvert
);
3352 // Remap `legalizableOps`, so that they point to the original ops and not the
3354 if (config
.legalizableOps
) {
3355 DenseSet
<Operation
*> originalLegalizableOps
;
3356 for (Operation
*op
: *config
.legalizableOps
)
3357 originalLegalizableOps
.insert(inverseOperationMap
[op
]);
3358 *config
.legalizableOps
= std::move(originalLegalizableOps
);
3361 // Erase the cloned IR.
3362 clonedAncestor
->erase();
3367 mlir::applyAnalysisConversion(Operation
*op
, ConversionTarget
&target
,
3368 const FrozenRewritePatternSet
&patterns
,
3369 ConversionConfig config
) {
3370 return applyAnalysisConversion(llvm::ArrayRef(op
), target
, patterns
, config
);