[ELF] Refactor merge-* tests
[llvm-project.git] / mlir / lib / Conversion / ControlFlowToLLVM / ControlFlowToLLVM.cpp
blobe5c735e10703a7852b3ae2b0b13a68f7fb8a02ed
1 //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the lowering of the MLIR ControlFlow (`cf`) dialect
// into the LLVM IR dialect, along with the corresponding conversion pass.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
16 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
17 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
18 #include "mlir/Conversion/LLVMCommon/Pattern.h"
19 #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
20 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
21 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
22 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
23 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
24 #include "mlir/IR/BuiltinOps.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Pass/Pass.h"
27 #include "mlir/Transforms/DialectConversion.h"
28 #include "llvm/ADT/StringRef.h"
29 #include <functional>
31 namespace mlir {
32 #define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS
33 #include "mlir/Conversion/Passes.h.inc"
34 } // namespace mlir
36 using namespace mlir;
38 #define PASS_NAME "convert-cf-to-llvm"
40 namespace {
41 /// Lower `cf.assert`. The default lowering calls the `abort` function if the
42 /// assertion is violated and has no effect otherwise. The failure message is
43 /// ignored by the default lowering but should be propagated by any custom
44 /// lowering.
45 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
46 explicit AssertOpLowering(const LLVMTypeConverter &typeConverter,
47 bool abortOnFailedAssert = true)
48 : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
49 abortOnFailedAssert(abortOnFailedAssert) {}
51 LogicalResult
52 matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
53 ConversionPatternRewriter &rewriter) const override {
54 auto loc = op.getLoc();
55 auto module = op->getParentOfType<ModuleOp>();
57 // Split block at `assert` operation.
58 Block *opBlock = rewriter.getInsertionBlock();
59 auto opPosition = rewriter.getInsertionPoint();
60 Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
62 // Failed block: Generate IR to print the message and call `abort`.
63 Block *failureBlock = rewriter.createBlock(opBlock->getParent());
64 LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
65 *getTypeConverter(), /*addNewLine=*/false,
66 /*runtimeFunctionName=*/"puts");
67 if (abortOnFailedAssert) {
68 // Insert the `abort` declaration if necessary.
69 auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
70 if (!abortFunc) {
71 OpBuilder::InsertionGuard guard(rewriter);
72 rewriter.setInsertionPointToStart(module.getBody());
73 auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
74 abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
75 "abort", abortFuncTy);
77 rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt);
78 rewriter.create<LLVM::UnreachableOp>(loc);
79 } else {
80 rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock);
83 // Generate assertion test.
84 rewriter.setInsertionPointToEnd(opBlock);
85 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
86 op, adaptor.getArg(), continuationBlock, failureBlock);
88 return success();
91 private:
92 /// If set to `false`, messages are printed but program execution continues.
93 /// This is useful for testing asserts.
94 bool abortOnFailedAssert = true;
97 /// The cf->LLVM lowerings for branching ops require that the blocks they jump
98 /// to first have updated types which should be handled by a pattern operating
99 /// on the parent op.
100 static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter,
101 ValueRange operands,
102 ValueRange blockArgs, Location loc,
103 llvm::StringRef messagePrefix) {
104 for (const auto &idxAndTypes :
105 llvm::enumerate(llvm::zip(blockArgs, operands))) {
106 int64_t i = idxAndTypes.index();
107 Value argValue =
108 rewriter.getRemappedValue(std::get<0>(idxAndTypes.value()));
109 Type operandType = std::get<1>(idxAndTypes.value()).getType();
110 // In the case of an invalid jump, the block argument will have been
111 // remapped to an UnrealizedConversionCast. In the case of a valid jump,
112 // there might still be a no-op conversion cast with both types being equal.
113 // Consider both of these details to see if the jump would be invalid.
114 if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>(
115 argValue.getDefiningOp())) {
116 if (op.getOperandTypes().front() != operandType) {
117 return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) {
118 diag << messagePrefix;
119 diag << "mismatched types from operand # " << i << " ";
120 diag << operandType;
121 diag << " not compatible with destination block argument type ";
122 diag << op.getOperandTypes().front();
123 diag << " which should be converted with the parent op.";
128 return success();
131 /// Ensure that all block types were updated and then create an LLVM::BrOp
132 struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
133 using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
135 LogicalResult
136 matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
137 ConversionPatternRewriter &rewriter) const override {
138 if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(),
139 op.getSuccessor()->getArguments(),
140 op.getLoc(),
141 /*messagePrefix=*/"")))
142 return failure();
144 rewriter.replaceOpWithNewOp<LLVM::BrOp>(
145 op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
146 return success();
150 /// Ensure that all block types were updated and then create an LLVM::CondBrOp
151 struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
152 using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
154 LogicalResult
155 matchAndRewrite(cf::CondBranchOp op,
156 typename cf::CondBranchOp::Adaptor adaptor,
157 ConversionPatternRewriter &rewriter) const override {
158 if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(),
159 op.getFalseDest()->getArguments(),
160 op.getLoc(), "in false case branch ")))
161 return failure();
162 if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(),
163 op.getTrueDest()->getArguments(),
164 op.getLoc(), "in true case branch ")))
165 return failure();
167 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
168 op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
169 return success();
173 /// Ensure that all block types were updated and then create an LLVM::SwitchOp
174 struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
175 using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
177 LogicalResult
178 matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
179 ConversionPatternRewriter &rewriter) const override {
180 if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(),
181 op.getDefaultDestination()->getArguments(),
182 op.getLoc(), "in switch default case ")))
183 return failure();
185 for (const auto &i : llvm::enumerate(
186 llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) {
187 if (failed(verifyMatchingValues(
188 rewriter, std::get<0>(i.value()),
189 std::get<1>(i.value())->getArguments(), op.getLoc(),
190 "in switch case " + std::to_string(i.index()) + " "))) {
191 return failure();
195 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
196 op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
197 return success();
201 } // namespace
203 void mlir::cf::populateControlFlowToLLVMConversionPatterns(
204 const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
205 // clang-format off
206 patterns.add<
207 AssertOpLowering,
208 BranchOpLowering,
209 CondBranchOpLowering,
210 SwitchOpLowering>(converter);
211 // clang-format on
214 void mlir::cf::populateAssertToLLVMConversionPattern(
215 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
216 bool abortOnFailure) {
217 patterns.add<AssertOpLowering>(converter, abortOnFailure);
220 //===----------------------------------------------------------------------===//
221 // Pass Definition
222 //===----------------------------------------------------------------------===//
224 namespace {
225 /// A pass converting MLIR operations into the LLVM IR dialect.
226 struct ConvertControlFlowToLLVM
227 : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> {
229 using Base::Base;
231 /// Run the dialect converter on the module.
232 void runOnOperation() override {
233 LLVMConversionTarget target(getContext());
234 RewritePatternSet patterns(&getContext());
236 LowerToLLVMOptions options(&getContext());
237 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
238 options.overrideIndexBitwidth(indexBitwidth);
240 LLVMTypeConverter converter(&getContext(), options);
241 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
243 if (failed(applyPartialConversion(getOperation(), target,
244 std::move(patterns))))
245 signalPassFailure();
248 } // namespace
250 //===----------------------------------------------------------------------===//
251 // ConvertToLLVMPatternInterface implementation
252 //===----------------------------------------------------------------------===//
254 namespace {
255 /// Implement the interface to convert MemRef to LLVM.
256 struct ControlFlowToLLVMDialectInterface
257 : public ConvertToLLVMPatternInterface {
258 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
259 void loadDependentDialects(MLIRContext *context) const final {
260 context->loadDialect<LLVM::LLVMDialect>();
263 /// Hook for derived dialect interface to provide conversion patterns
264 /// and mark dialect legal for the conversion target.
265 void populateConvertToLLVMConversionPatterns(
266 ConversionTarget &target, LLVMTypeConverter &typeConverter,
267 RewritePatternSet &patterns) const final {
268 mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
269 patterns);
272 } // namespace
274 void mlir::cf::registerConvertControlFlowToLLVMInterface(
275 DialectRegistry &registry) {
276 registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) {
277 dialect->addInterfaces<ControlFlowToLLVMDialectInterface>();