From aa65473c9ddcf3cbb80e63c38af842d05346374b Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 21 Nov 2024 10:26:05 +0900 Subject: [PATCH] [mlir][Transforms][NFC] Dialect conversion: Remove "finalize" phase (#116934) The dialect conversion driver has three phases: - **Create** `IRRewrite` objects as the IR is traversed. - **Finalize** `IRRewrite` objects. During this phase, source materializations for mismatching value types are created. (E.g., when `Value` is replaced with a `Value` of different type, but there is a user of the original value that was not modified because it is already legal.) - **Commit** `IRRewrite` objects. During this phase, all remaining IR modifications are materialized. In particular, SSA values are actually being replaced during this phase. This commit removes the "finalize" phase. This simplifies the code base a bit and avoids one traversal over the `IRRewrite` stack. Source materializations are now built during the "commit" phase, right before an SSA value is being replaced. This commit also removes the "inverse mapping" of the conversion value mapping, which was used to predict if an SSA value will be dead at the end of the conversion. This check is replaced with an approximate check that does not require an inverse mapping. (A false positive for `v` can occur if another value `v2` is mapped to `v` and `v2` turns out to be dead at the end of the conversion. This case is not expected to happen very often.) This reduces the complexity of the driver a bit and removes one potential source of bugs. (There have been bugs in the usage of the inverse mapping in the past.) `BlockTypeConversionRewrite` no longer stores a pointer to the type converter. This pointer is now stored in `ReplaceBlockArgRewrite`. This commit is in preparation of merging the 1:1 and 1:N dialect conversion driver. It simplifies the upcoming changes around the conversion value mapping. (API surface of the conversion value mapping is reduced.) --- mlir/lib/Transforms/Utils/DialectConversion.cpp | 184 ++++++++++-------------- 1 file changed, 72 insertions(+), 112 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 42fe5b925654..03d483f73f25 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -75,6 +75,10 @@ namespace { /// This class wraps a IRMapping to provide recursive lookup /// functionality, i.e. we will traverse if the mapped value also has a mapping. struct ConversionValueMapping { + /// Return "true" if an SSA value is mapped to the given value. May return + /// false positives. + bool isMappedTo(Value value) const { return mappedTo.contains(value); } + /// Lookup the most recently mapped value with the desired type in the /// mapping. /// @@ -99,22 +103,18 @@ struct ConversionValueMapping { assert(it != oldVal && "inserting cyclic mapping"); }); mapping.map(oldVal, newVal); + mappedTo.insert(newVal); } /// Drop the last mapping for the given value. void erase(Value value) { mapping.erase(value); } - /// Returns the inverse raw value mapping (without recursive query support). - DenseMap> getInverse() const { - DenseMap> inverse; - for (auto &it : mapping.getValueMap()) - inverse[it.second].push_back(it.first); - return inverse; - } - private: /// Current value mappings. IRMapping mapping; + + /// All SSA values that are mapped to. May contain false positives. + DenseSet mappedTo; }; } // namespace @@ -434,10 +434,9 @@ private: class BlockTypeConversionRewrite : public BlockRewrite { public: BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl, - Block *block, Block *origBlock, - const TypeConverter *converter) + Block *block, Block *origBlock) : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block), - origBlock(origBlock), converter(converter) {} + origBlock(origBlock) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::BlockTypeConversion; @@ -445,8 +444,6 @@ public: Block *getOrigBlock() const { return origBlock; } - const TypeConverter *getConverter() const { return converter; } - void commit(RewriterBase &rewriter) override; void rollback() override; @@ -454,9 +451,6 @@ public: private: /// The original block that was requested to have its signature converted. Block *origBlock; - - /// The type converter used to convert the arguments. - const TypeConverter *converter; }; /// Replacing a block argument. This rewrite is not immediately reflected in the @@ -465,8 +459,10 @@ private: class ReplaceBlockArgRewrite : public BlockRewrite { public: ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl, - Block *block, BlockArgument arg) - : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {} + Block *block, BlockArgument arg, + const TypeConverter *converter) + : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg), + converter(converter) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::ReplaceBlockArg; @@ -478,6 +474,9 @@ public: private: BlockArgument arg; + + /// The current type converter when the block argument was replaced. + const TypeConverter *converter; }; /// An operation rewrite. @@ -627,8 +626,6 @@ public: void cleanup(RewriterBase &rewriter) override; - const TypeConverter *getConverter() const { return converter; } - private: /// An optional type converter that can be used to materialize conversions /// between the new and old values if necessary. @@ -825,6 +822,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { ValueRange replacements, Value originalValue, const TypeConverter *converter); + /// Find a replacement value for the given SSA value in the conversion value + /// mapping. The replacement value must have the same type as the given SSA + /// value. If there is no replacement value with the correct type, find the + /// latest replacement value (regardless of the type) and build a source + /// materialization. + Value findOrBuildReplacementValue(Value value, + const TypeConverter *converter); + //===--------------------------------------------------------------------===// // Rewriter Notification Hooks //===--------------------------------------------------------------------===// @@ -970,7 +975,7 @@ void BlockTypeConversionRewrite::rollback() { } void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { - Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType()); + Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter); if (!repl) return; @@ -999,7 +1004,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { // Compute replacement values. SmallVector replacements = llvm::map_to_vector(op->getResults(), [&](OpResult result) { - return rewriterImpl.mapping.lookupOrNull(result, result.getType()); + return rewriterImpl.findOrBuildReplacementValue(result, converter); }); // Notify the listener that the operation is about to be replaced. @@ -1069,8 +1074,10 @@ void UnresolvedMaterializationRewrite::rollback() { void ConversionPatternRewriterImpl::applyRewrites() { // Commit all rewrites. IRRewriter rewriter(context, config.listener); - for (auto &rewrite : rewrites) - rewrite->commit(rewriter); + // Note: New rewrites may be added during the "commit" phase and the + // `rewrites` vector may reallocate. + for (size_t i = 0; i < rewrites.size(); ++i) + rewrites[i]->commit(rewriter); // Clean up all rewrites. for (auto &rewrite : rewrites) @@ -1275,7 +1282,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( /*inputs=*/ValueRange(), /*outputType=*/origArgType, /*originalType=*/Type(), converter); mapping.map(origArg, repl); - appendRewrite(block, origArg); + appendRewrite(block, origArg, converter); continue; } @@ -1285,7 +1292,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( "invalid to provide a replacement value when the argument isn't " "dropped"); mapping.map(origArg, repl); - appendRewrite(block, origArg); + appendRewrite(block, origArg, converter); continue; } @@ -1298,10 +1305,10 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( insertNTo1Materialization( OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), /*replacements=*/replArgs, /*outputValue=*/origArg, converter); - appendRewrite(block, origArg); + appendRewrite(block, origArg, converter); } - appendRewrite(newBlock, block, converter); + appendRewrite(newBlock, block); // Erase the old block. (It is just unlinked for now and will be erased during // cleanup.) @@ -1371,6 +1378,41 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization( } } +Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( + Value value, const TypeConverter *converter) { + // Find a replacement value with the same type. + Value repl = mapping.lookupOrNull(value, value.getType()); + if (repl) + return repl; + + // Check if the value is dead. No replacement value is needed in that case. + // This is an approximate check that may have false negatives but does not + // require computing and traversing an inverse mapping. (We may end up + // building source materializations that are never used and that fold away.) + if (llvm::all_of(value.getUsers(), + [&](Operation *op) { return replacedOps.contains(op); }) && + !mapping.isMappedTo(value)) + return Value(); + + // No replacement value was found. Get the latest replacement value + // (regardless of the type) and build a source materialization to the + // original type. + repl = mapping.lookupOrNull(value); + if (!repl) { + // No replacement value is registered in the mapping. This means that the + // value is dropped and no longer needed. (If the value were still needed, + // a source materialization producing a replacement value "out of thin air" + // would have already been created during `replaceOp` or + // `applySignatureConversion`.) + return Value(); + } + Value castValue = buildUnresolvedMaterialization( + MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(), + /*inputs=*/repl, /*outputType=*/value.getType(), + /*originalType=*/Type(), converter); + return castValue; +} + //===----------------------------------------------------------------------===// // Rewriter Notification Hooks @@ -1597,7 +1639,8 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, << "'(in region of '" << parentOp->getName() << "'(" << from.getOwner()->getParentOp() << ")\n"; }); - impl->appendRewrite(from.getOwner(), from); + impl->appendRewrite(from.getOwner(), from, + impl->currentTypeConverter); impl->mapping.map(impl->mapping.lookupOrDefault(from), to); } @@ -2417,10 +2460,6 @@ private: /// Converts an operation with the given rewriter. LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op); - /// This method is called after the conversion process to legalize any - /// remaining artifacts and complete the conversion. - void finalize(ConversionPatternRewriter &rewriter); - /// Dialect conversion configuration. ConversionConfig config; @@ -2541,11 +2580,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { if (failed(convert(rewriter, op))) return rewriterImpl.undoRewrites(), failure(); - // Now that all of the operations have been converted, finalize the conversion - // process to ensure any lingering conversion artifacts are cleaned up and - // legalized. - finalize(rewriter); - // After a successful conversion, apply rewrites. rewriterImpl.applyRewrites(); @@ -2579,80 +2613,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { return success(); } -/// Finds a user of the given value, or of any other value that the given value -/// replaced, that was not replaced in the conversion process. -static Operation *findLiveUserOfReplaced( - Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, - const DenseMap> &inverseMapping) { - SmallVector worklist = {initialValue}; - while (!worklist.empty()) { - Value value = worklist.pop_back_val(); - - // Walk the users of this value to see if there are any live users that - // weren't replaced during conversion. - auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) { - return rewriterImpl.isOpIgnored(user); - }); - if (liveUserIt != value.user_end()) - return *liveUserIt; - auto mapIt = inverseMapping.find(value); - if (mapIt != inverseMapping.end()) - worklist.append(mapIt->second); - } - return nullptr; -} - -/// Helper function that returns the replaced values and the type converter if -/// the given rewrite object is an "operation replacement" or a "block type -/// conversion" (which corresponds to a "block replacement"). Otherwise, return -/// an empty ValueRange and a null type converter pointer. -static std::pair -getReplacedValues(IRRewrite *rewrite) { - if (auto *opRewrite = dyn_cast(rewrite)) - return {opRewrite->getOperation()->getResults(), opRewrite->getConverter()}; - if (auto *blockRewrite = dyn_cast(rewrite)) - return {blockRewrite->getOrigBlock()->getArguments(), - blockRewrite->getConverter()}; - return {}; -} - -void OperationConverter::finalize(ConversionPatternRewriter &rewriter) { - ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); - DenseMap> inverseMapping = - rewriterImpl.mapping.getInverse(); - - // Process requested value replacements. - for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) { - ValueRange replacedValues; - const TypeConverter *converter; - std::tie(replacedValues, converter) = - getReplacedValues(rewriterImpl.rewrites[i].get()); - for (Value originalValue : replacedValues) { - // If the type of this value changed and the value is still live, we need - // to materialize a conversion. - if (rewriterImpl.mapping.lookupOrNull(originalValue, - originalValue.getType())) - continue; - Operation *liveUser = - findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping); - if (!liveUser) - continue; - - // Legalize this value replacement. - Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue); - assert(newValue && "replacement value not found"); - Value castValue = rewriterImpl.buildUnresolvedMaterialization( - MaterializationKind::Source, computeInsertPoint(newValue), - originalValue.getLoc(), - /*inputs=*/newValue, /*outputType=*/originalValue.getType(), - /*originalType=*/Type(), converter); - rewriterImpl.mapping.map(originalValue, castValue); - inverseMapping[castValue].push_back(originalValue); - llvm::erase(inverseMapping[newValue], originalValue); - } - } -} - //===----------------------------------------------------------------------===// // Reconcile Unrealized Casts //===----------------------------------------------------------------------===// -- 2.11.4.GIT