//===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert the MLIR ControlFlow dialect into
// the LLVM IR dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/StringRef.h"
#include <cstdint>
#include <functional>
#include <optional>
#include <string>
#include <tuple>
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

#define PASS_NAME "convert-cf-to-llvm"
41 /// Lower `cf.assert`. The default lowering calls the `abort` function if the
42 /// assertion is violated and has no effect otherwise. The failure message is
43 /// ignored by the default lowering but should be propagated by any custom
45 struct AssertOpLowering
: public ConvertOpToLLVMPattern
<cf::AssertOp
> {
46 explicit AssertOpLowering(const LLVMTypeConverter
&typeConverter
,
47 bool abortOnFailedAssert
= true)
48 : ConvertOpToLLVMPattern
<cf::AssertOp
>(typeConverter
, /*benefit=*/1),
49 abortOnFailedAssert(abortOnFailedAssert
) {}
52 matchAndRewrite(cf::AssertOp op
, OpAdaptor adaptor
,
53 ConversionPatternRewriter
&rewriter
) const override
{
54 auto loc
= op
.getLoc();
55 auto module
= op
->getParentOfType
<ModuleOp
>();
57 // Split block at `assert` operation.
58 Block
*opBlock
= rewriter
.getInsertionBlock();
59 auto opPosition
= rewriter
.getInsertionPoint();
60 Block
*continuationBlock
= rewriter
.splitBlock(opBlock
, opPosition
);
62 // Failed block: Generate IR to print the message and call `abort`.
63 Block
*failureBlock
= rewriter
.createBlock(opBlock
->getParent());
64 LLVM::createPrintStrCall(rewriter
, loc
, module
, "assert_msg", op
.getMsg(),
65 *getTypeConverter(), /*addNewLine=*/false,
66 /*runtimeFunctionName=*/"puts");
67 if (abortOnFailedAssert
) {
68 // Insert the `abort` declaration if necessary.
69 auto abortFunc
= module
.lookupSymbol
<LLVM::LLVMFuncOp
>("abort");
71 OpBuilder::InsertionGuard
guard(rewriter
);
72 rewriter
.setInsertionPointToStart(module
.getBody());
73 auto abortFuncTy
= LLVM::LLVMFunctionType::get(getVoidType(), {});
74 abortFunc
= rewriter
.create
<LLVM::LLVMFuncOp
>(rewriter
.getUnknownLoc(),
75 "abort", abortFuncTy
);
77 rewriter
.create
<LLVM::CallOp
>(loc
, abortFunc
, std::nullopt
);
78 rewriter
.create
<LLVM::UnreachableOp
>(loc
);
80 rewriter
.create
<LLVM::BrOp
>(loc
, ValueRange(), continuationBlock
);
83 // Generate assertion test.
84 rewriter
.setInsertionPointToEnd(opBlock
);
85 rewriter
.replaceOpWithNewOp
<LLVM::CondBrOp
>(
86 op
, adaptor
.getArg(), continuationBlock
, failureBlock
);
92 /// If set to `false`, messages are printed but program execution continues.
93 /// This is useful for testing asserts.
94 bool abortOnFailedAssert
= true;
97 /// The cf->LLVM lowerings for branching ops require that the blocks they jump
98 /// to first have updated types which should be handled by a pattern operating
100 static LogicalResult
verifyMatchingValues(ConversionPatternRewriter
&rewriter
,
102 ValueRange blockArgs
, Location loc
,
103 llvm::StringRef messagePrefix
) {
104 for (const auto &idxAndTypes
:
105 llvm::enumerate(llvm::zip(blockArgs
, operands
))) {
106 int64_t i
= idxAndTypes
.index();
108 rewriter
.getRemappedValue(std::get
<0>(idxAndTypes
.value()));
109 Type operandType
= std::get
<1>(idxAndTypes
.value()).getType();
110 // In the case of an invalid jump, the block argument will have been
111 // remapped to an UnrealizedConversionCast. In the case of a valid jump,
112 // there might still be a no-op conversion cast with both types being equal.
113 // Consider both of these details to see if the jump would be invalid.
114 if (auto op
= dyn_cast_or_null
<UnrealizedConversionCastOp
>(
115 argValue
.getDefiningOp())) {
116 if (op
.getOperandTypes().front() != operandType
) {
117 return rewriter
.notifyMatchFailure(loc
, [&](Diagnostic
&diag
) {
118 diag
<< messagePrefix
;
119 diag
<< "mismatched types from operand # " << i
<< " ";
121 diag
<< " not compatible with destination block argument type ";
122 diag
<< op
.getOperandTypes().front();
123 diag
<< " which should be converted with the parent op.";
131 /// Ensure that all block types were updated and then create an LLVM::BrOp
132 struct BranchOpLowering
: public ConvertOpToLLVMPattern
<cf::BranchOp
> {
133 using ConvertOpToLLVMPattern
<cf::BranchOp
>::ConvertOpToLLVMPattern
;
136 matchAndRewrite(cf::BranchOp op
, typename
cf::BranchOp::Adaptor adaptor
,
137 ConversionPatternRewriter
&rewriter
) const override
{
138 if (failed(verifyMatchingValues(rewriter
, adaptor
.getDestOperands(),
139 op
.getSuccessor()->getArguments(),
141 /*messagePrefix=*/"")))
144 rewriter
.replaceOpWithNewOp
<LLVM::BrOp
>(
145 op
, adaptor
.getOperands(), op
->getSuccessors(), op
->getAttrs());
150 /// Ensure that all block types were updated and then create an LLVM::CondBrOp
151 struct CondBranchOpLowering
: public ConvertOpToLLVMPattern
<cf::CondBranchOp
> {
152 using ConvertOpToLLVMPattern
<cf::CondBranchOp
>::ConvertOpToLLVMPattern
;
155 matchAndRewrite(cf::CondBranchOp op
,
156 typename
cf::CondBranchOp::Adaptor adaptor
,
157 ConversionPatternRewriter
&rewriter
) const override
{
158 if (failed(verifyMatchingValues(rewriter
, adaptor
.getFalseDestOperands(),
159 op
.getFalseDest()->getArguments(),
160 op
.getLoc(), "in false case branch ")))
162 if (failed(verifyMatchingValues(rewriter
, adaptor
.getTrueDestOperands(),
163 op
.getTrueDest()->getArguments(),
164 op
.getLoc(), "in true case branch ")))
167 rewriter
.replaceOpWithNewOp
<LLVM::CondBrOp
>(
168 op
, adaptor
.getOperands(), op
->getSuccessors(), op
->getAttrs());
173 /// Ensure that all block types were updated and then create an LLVM::SwitchOp
174 struct SwitchOpLowering
: public ConvertOpToLLVMPattern
<cf::SwitchOp
> {
175 using ConvertOpToLLVMPattern
<cf::SwitchOp
>::ConvertOpToLLVMPattern
;
178 matchAndRewrite(cf::SwitchOp op
, typename
cf::SwitchOp::Adaptor adaptor
,
179 ConversionPatternRewriter
&rewriter
) const override
{
180 if (failed(verifyMatchingValues(rewriter
, adaptor
.getDefaultOperands(),
181 op
.getDefaultDestination()->getArguments(),
182 op
.getLoc(), "in switch default case ")))
185 for (const auto &i
: llvm::enumerate(
186 llvm::zip(adaptor
.getCaseOperands(), op
.getCaseDestinations()))) {
187 if (failed(verifyMatchingValues(
188 rewriter
, std::get
<0>(i
.value()),
189 std::get
<1>(i
.value())->getArguments(), op
.getLoc(),
190 "in switch case " + std::to_string(i
.index()) + " "))) {
195 rewriter
.replaceOpWithNewOp
<LLVM::SwitchOp
>(
196 op
, adaptor
.getOperands(), op
->getSuccessors(), op
->getAttrs());
203 void mlir::cf::populateControlFlowToLLVMConversionPatterns(
204 const LLVMTypeConverter
&converter
, RewritePatternSet
&patterns
) {
209 CondBranchOpLowering
,
210 SwitchOpLowering
>(converter
);
214 void mlir::cf::populateAssertToLLVMConversionPattern(
215 const LLVMTypeConverter
&converter
, RewritePatternSet
&patterns
,
216 bool abortOnFailure
) {
217 patterns
.add
<AssertOpLowering
>(converter
, abortOnFailure
);
//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
225 /// A pass converting MLIR operations into the LLVM IR dialect.
226 struct ConvertControlFlowToLLVM
227 : public impl::ConvertControlFlowToLLVMPassBase
<ConvertControlFlowToLLVM
> {
231 /// Run the dialect converter on the module.
232 void runOnOperation() override
{
233 LLVMConversionTarget
target(getContext());
234 RewritePatternSet
patterns(&getContext());
236 LowerToLLVMOptions
options(&getContext());
237 if (indexBitwidth
!= kDeriveIndexBitwidthFromDataLayout
)
238 options
.overrideIndexBitwidth(indexBitwidth
);
240 LLVMTypeConverter
converter(&getContext(), options
);
241 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter
, patterns
);
243 if (failed(applyPartialConversion(getOperation(), target
,
244 std::move(patterns
))))
//===----------------------------------------------------------------------===//
// ConvertToLLVMPatternInterface implementation
//===----------------------------------------------------------------------===//
255 /// Implement the interface to convert MemRef to LLVM.
256 struct ControlFlowToLLVMDialectInterface
257 : public ConvertToLLVMPatternInterface
{
258 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface
;
259 void loadDependentDialects(MLIRContext
*context
) const final
{
260 context
->loadDialect
<LLVM::LLVMDialect
>();
263 /// Hook for derived dialect interface to provide conversion patterns
264 /// and mark dialect legal for the conversion target.
265 void populateConvertToLLVMConversionPatterns(
266 ConversionTarget
&target
, LLVMTypeConverter
&typeConverter
,
267 RewritePatternSet
&patterns
) const final
{
268 mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter
,
274 void mlir::cf::registerConvertControlFlowToLLVMInterface(
275 DialectRegistry
®istry
) {
276 registry
.addExtension(+[](MLIRContext
*ctx
, cf::ControlFlowDialect
*dialect
) {
277 dialect
->addInterfaces
<ControlFlowToLLVMDialectInterface
>();